From 12bc2efa5f319310759f9b183974e8e49f7845a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:05:38 +0800 Subject: [PATCH 0001/1089] cherry pick --- python/sglang/srt/managers/io_struct.py | 10 ++++++++++ python/sglang/srt/managers/scheduler.py | 17 ++++++++++++++++- python/sglang/srt/managers/tokenizer_manager.py | 16 +++++++++++++++- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 0e1d5016524..06e5d0cff0f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -738,3 +738,13 @@ class RpcReqInput: class RpcReqOutput: success: bool message: str + + +class BlockReqType(Enum): + BLOCK = 1 + UNBLOCK = 2 + + +@dataclass +class BlockReqInput: + type: BlockReqType diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 71a7e2c3a80..635a91d6cde 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -55,6 +55,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, + BlockReqInput, + BlockReqType, CloseSessionReqInput, FlushCacheReq, GetInternalStateReq, @@ -359,6 +361,7 @@ def __init__( self.init_new_token_ratio - self.min_new_token_ratio ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio + self._blocked = False # Init watchdog thread self.watchdog_timeout = server_args.watchdog_timeout @@ -403,6 +406,7 @@ def __init__( (GetInternalStateReq, self.get_internal_state), (SetInternalStateReq, self.set_internal_state), (RpcReqInput, self.handle_rpc_request), + (BlockReqInput, self.handle_block_request), ] ) @@ -584,6 +588,8 @@ def event_loop_normal(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + if self._blocked: + continue batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -606,6 +612,8 @@ def event_loop_overlap(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + if self._blocked: + continue batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -1694,6 +1702,14 @@ def handle_rpc_request(self, recv_req: RpcReqInput): barrier() return RpcReqOutput(success, "" if not exec else str(exec)) + def handle_block_request(self, recv_req: BlockReqInput): + if recv_req.type == BlockReqType.BLOCK: + self._blocked = True + elif recv_req.type == BlockReqType.UNBLOCK: + self._blocked = False + else: + raise NotImplementedError(f"{recv_req=}") + def save_remote_model(self, params): url = params["url"] @@ -1942,7 +1958,6 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - # Generate the prefix if dp_rank is None: prefix = f" TP{tp_rank}" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 04082ab5891..011c6a59bd8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -16,7 +16,6 @@ import asyncio import copy import dataclasses -import json import logging import os import pickle @@ -62,6 +61,8 @@ BatchMultimodalOut, BatchStrOut, BatchTokenIDOut, + BlockReqInput, + BlockReqType, CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, @@ -98,6 +99,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( dataclass_to_string_truncated, + get_bool_env_var, 
get_zmq_socket, kill_process_tree, ) @@ -470,6 +472,9 @@ def _send_one_request( self.rid_to_state[obj.rid] = state self.send_to_scheduler.send_pyobj(tokenized_obj) + def _send_block_request(self, type: BlockReqType): + self.send_to_scheduler.send_pyobj(BlockReqInput(type)) + async def _wait_one_response( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -539,12 +544,16 @@ async def _handle_batch_request( rids = [] if getattr(obj, "parallel_sample_num", 1) == 1: # Send all requests + if _ENABLE_COLOCATED_BATCH_GEN: + self._send_block_request(BlockReqType.BLOCK) for i in range(batch_size): tmp_obj = obj[i] tokenized_obj = await self._tokenize_one_request(tmp_obj) self._send_one_request(tmp_obj, tokenized_obj, created_time) generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) + if _ENABLE_COLOCATED_BATCH_GEN: + self._send_block_request(BlockReqType.UNBLOCK) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. if batch_size > 128: @@ -1191,3 +1200,8 @@ def handle_recv(self, recv_obj: T): self._result_values.append(recv_obj) if len(self._result_values) == self._fan_out: self._result_event.set() + + +_ENABLE_COLOCATED_BATCH_GEN = get_bool_env_var( + "SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false" +) From 3dbf4333e8e33e3d68baec4412e17fc1d0ed06b0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:07:17 +0800 Subject: [PATCH 0002/1089] more --- .../sglang/srt/managers/tokenizer_manager.py | 16 ++++------ python/sglang/srt/utils.py | 29 ++++++++++--------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 011c6a59bd8..7043967763c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -101,7 +100,7 @@ dataclass_to_string_truncated, get_bool_env_var, get_zmq_socket, - kill_process_tree, + kill_process_tree, ENABLE_COLOCATED_BATCH_GEN, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback @@ -544,7 +543,7 @@ async def _handle_batch_request( rids = [] if getattr(obj, "parallel_sample_num", 1) == 1: # Send all requests - if _ENABLE_COLOCATED_BATCH_GEN: + if ENABLE_COLOCATED_BATCH_GEN: self._send_block_request(BlockReqType.BLOCK) for i in range(batch_size): tmp_obj = obj[i] @@ -552,7 +551,7 @@ async def _handle_batch_request( self._send_one_request(tmp_obj, tokenized_obj, created_time) generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) - if _ENABLE_COLOCATED_BATCH_GEN: + if ENABLE_COLOCATED_BATCH_GEN: self._send_block_request(BlockReqType.UNBLOCK) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. 
@@ -948,8 +947,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] @@ -1200,8 +1199,3 @@ def handle_recv(self, recv_obj: T): self._result_values.append(recv_obj) if len(self._result_values) == self._fan_out: self._result_event.set() - - -_ENABLE_COLOCATED_BATCH_GEN = get_bool_env_var( - "SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false" -) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index af2907f4154..ab5233572a1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -55,14 +55,12 @@ import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version -from packaging.version import Version, parse from starlette.routing import Mount from torch import nn from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function from torch.utils._contextlib import _DecoratorContextManager -from torch.utils.cpp_extension import CUDA_HOME from triton.runtime.cache import ( FileCacheManager, default_cache_dir, @@ -190,6 +188,11 @@ def clone(self) -> "DynamicGradMode": return self.__class__() +ENABLE_COLOCATED_BATCH_GEN = get_bool_env_var( + "SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false" +) + + def enable_show_time_cost(): global show_time_cost show_time_cost = True @@ -852,10 +855,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 + total_mem = mem.total / 1024 ** 3 + available_mem = mem.available / 1024 ** 3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) + buf_size = int(0.5 * 1024 ** 3) else: buf_size = -1 @@ -1416,10 +1419,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1427,10 +1430,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: From 8e7fd002791dc80bc55967d8ce810e2b2a656879 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:07:55 +0800 Subject: [PATCH 0003/1089] Revert "cherry pick" This reverts commit 12bc2efa5f319310759f9b183974e8e49f7845a0. 
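
The scheduler-side _blocked flag from the cherry-picked commit is removed again
here; the commits that follow re-land the same behavior as a dedicated
SchedulerInputBlocker that buffers incoming requests inside the scheduler
instead of skipping whole event-loop iterations. For orientation, a minimal
self-contained sketch of the flow this series builds toward: BlockReqType and
BlockReqInput are copied from the io_struct diff in the first patch, while
send_batch and send_pyobj are illustrative stand-ins for the tokenizer
manager's _send_block_request / send_to_scheduler.send_pyobj plumbing, not a
real API.

    from dataclasses import dataclass
    from enum import Enum


    class BlockReqType(Enum):
        BLOCK = 1
        UNBLOCK = 2


    @dataclass
    class BlockReqInput:
        type: BlockReqType


    def send_batch(send_pyobj, tokenized_reqs):
        # Bracket the batch: after BLOCK the scheduler parks newly received
        # requests in a pending list; once every rank has processed UNBLOCK
        # (the all_reduce MIN barrier added in the later patches), the
        # backlog is released into the normal scheduling path in one go.
        send_pyobj(BlockReqInput(BlockReqType.BLOCK))
        for req in tokenized_reqs:
            send_pyobj(req)
        send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))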
--- python/sglang/srt/managers/scheduler.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 635a91d6cde..71a7e2c3a80 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -55,8 +55,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, - BlockReqInput, - BlockReqType, CloseSessionReqInput, FlushCacheReq, GetInternalStateReq, @@ -361,7 +359,6 @@ def __init__( self.init_new_token_ratio - self.min_new_token_ratio ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio - self._blocked = False # Init watchdog thread self.watchdog_timeout = server_args.watchdog_timeout @@ -406,7 +403,6 @@ def __init__( (GetInternalStateReq, self.get_internal_state), (SetInternalStateReq, self.set_internal_state), (RpcReqInput, self.handle_rpc_request), - (BlockReqInput, self.handle_block_request), ] ) @@ -588,8 +584,6 @@ def event_loop_normal(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) - if self._blocked: - continue batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -612,8 +606,6 @@ def event_loop_overlap(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) - if self._blocked: - continue batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -1702,14 +1694,6 @@ def handle_rpc_request(self, recv_req: RpcReqInput): barrier() return RpcReqOutput(success, "" if not exec else str(exec)) - def handle_block_request(self, recv_req: BlockReqInput): - if recv_req.type == BlockReqType.BLOCK: - self._blocked = True - elif recv_req.type == BlockReqType.UNBLOCK: - self._blocked = False - else: - raise NotImplementedError(f"{recv_req=}") - def save_remote_model(self, params): url = params["url"] @@ -1958,6 +1942,7 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): + # Generate the prefix if dp_rank is None: prefix = f" TP{tp_rank}" From 25763ed1c5ebe7af09fe2085edbbf6500e7d585a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:08:21 +0800 Subject: [PATCH 0004/1089] more --- .../srt/managers/scheduler_input_blocker.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 python/sglang/srt/managers/scheduler_input_blocker.py diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py new file mode 100644 index 00000000000..5f4fb333113 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -0,0 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +class SchedulerInputBlocker: + pass From c4e2562fc5985f16f9405f58efe9ca842a2a9645 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:09:00 +0800 Subject: [PATCH 0005/1089] more --- python/sglang/srt/managers/scheduler.py | 28 +++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 71a7e2c3a80..2adf5c04683 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -98,6 +98,7 @@ PrefillAdder, SchedulePolicy, ) +from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker from sglang.srt.managers.scheduler_output_processor_mixin import ( SchedulerOutputProcessorMixin, ) @@ -124,7 +125,7 @@ pyspy_dump_schedulers, set_gpu_proc_affinity, set_random_seed, - suppress_other_loggers, + suppress_other_loggers, ENABLE_COLOCATED_BATCH_GEN, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback @@ -356,8 +357,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -371,6 +372,8 @@ def __init__( enable=server_args.enable_memory_saver ) + self.input_blocker = SchedulerInputBlocker() if ENABLE_COLOCATED_BATCH_GEN else None + # Init profiler self.torch_profiler = None self.torch_profiler_output_dir: Optional[str] = None @@ -1232,10 +1235,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1260,9 +1263,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1466,8 +1469,8 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # We should have at least 1 token for sample in every case. 
max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) @@ -1942,7 +1945,6 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - # Generate the prefix if dp_rank is None: prefix = f" TP{tp_rank}" From 844b3be0e43acb124410c7a0702a665695855c22 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:09:39 +0800 Subject: [PATCH 0006/1089] more --- python/sglang/srt/managers/scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2adf5c04683..60867a422a3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -726,6 +726,9 @@ def recv_requests(self) -> List[Req]: else: recv_reqs = None + if self.input_blocker is not None: + recv_reqs = self.input_blocker.handle(recv_reqs) + if self.server_args.enable_dp_attention: if self.attn_tp_rank == 0: work_reqs = [ From cf1d1dc6f6dc7db38a923b4fcf2bd241d82c1274 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:09:58 +0800 Subject: [PATCH 0007/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 5f4fb333113..fdb4a8c825d 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -11,6 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from typing import List + class SchedulerInputBlocker: - pass + def handle(self, recv_reqs: List): + TODO From 2bea87e8d0a0f2633502767271238547d7c11a5a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:10:22 +0800 Subject: [PATCH 0008/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index fdb4a8c825d..e4e8bc8b0cd 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -11,9 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import List +from typing import List, Optional, Any class SchedulerInputBlocker: - def handle(self, recv_reqs: List): + def handle(self, recv_reqs: Optional[List[Any]]): TODO From 49e1a7d93ea8f1167a3ddfa1cd8b9fdab7c0f8cf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:11:32 +0800 Subject: [PATCH 0009/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index e4e8bc8b0cd..748552b8c9c 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -11,9 +11,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from enum import Enum, auto from typing import List, Optional, Any class SchedulerInputBlocker: + def __init__(self): + self._state = _State.UNBLOCKED + def handle(self, recv_reqs: Optional[List[Any]]): TODO + + +class _State(Enum): + UNBLOCKED = auto() + BLOCKED = auto() + AWAITING_GLOBAL_UNBLOCK = auto() From 323885101369113e64a2c66d73971e8fa4a4d3e0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:12:17 +0800 Subject: [PATCH 0010/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 748552b8c9c..8414e921d09 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -14,6 +14,8 @@ from enum import Enum, auto from typing import List, Optional, Any +from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType + class SchedulerInputBlocker: def __init__(self): @@ -22,6 +24,14 @@ def __init__(self): def handle(self, recv_reqs: Optional[List[Any]]): TODO + def _execute_block_request(self, recv_req: BlockReqInput): + if recv_req.type == BlockReqType.BLOCK: + TODO + elif recv_req.type == BlockReqType.UNBLOCK: + TODO + else: + raise NotImplementedError(f"{recv_req=}") + class _State(Enum): UNBLOCKED = auto() From 768ef7318560ee51a1d7b3336c9b0357b2fc1137 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:13:08 +0800 Subject: [PATCH 0011/1089] more --- .../srt/managers/scheduler_input_blocker.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 8414e921d09..65a173b2813 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -22,15 +22,25 @@ def __init__(self): self._state = _State.UNBLOCKED def handle(self, recv_reqs: Optional[List[Any]]): - TODO + for recv_req in recv_reqs: + self._handle_recv_req(recv_req) - def _execute_block_request(self, recv_req: BlockReqInput): - if recv_req.type == BlockReqType.BLOCK: - TODO - elif recv_req.type == BlockReqType.UNBLOCK: - TODO + def _handle_recv_req(self, recv_req): + if isinstance(recv_req, BlockReqInput): + if recv_req.type == BlockReqType.BLOCK: + self._execute_block_request() + elif recv_req.type == BlockReqType.UNBLOCK: + 
self._execute_unblock_request() + else: + raise NotImplementedError(f"{recv_req=}") else: - raise NotImplementedError(f"{recv_req=}") + TODO + + def _execute_block_request(self): + TODO + + def _execute_unblock_request(self): + TODO class _State(Enum): From 08f0936899c44aa290b16e82156430b1284aa030 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:13:28 +0800 Subject: [PATCH 0012/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 65a173b2813..0c8ad99ca12 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -22,19 +22,22 @@ def __init__(self): self._state = _State.UNBLOCKED def handle(self, recv_reqs: Optional[List[Any]]): + output_reqs = [] for recv_req in recv_reqs: - self._handle_recv_req(recv_req) + output_reqs += self._handle_recv_req(recv_req) def _handle_recv_req(self, recv_req): if isinstance(recv_req, BlockReqInput): if recv_req.type == BlockReqType.BLOCK: self._execute_block_request() + return [] elif recv_req.type == BlockReqType.UNBLOCK: self._execute_unblock_request() + return [] else: raise NotImplementedError(f"{recv_req=}") else: - TODO + return TODO def _execute_block_request(self): TODO From df77161075dc66850fcf1bd1a2c32326b0a4dd88 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:13:40 +0800 Subject: [PATCH 0013/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 0c8ad99ca12..993a973e0b8 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -25,6 +25,7 @@ def handle(self, recv_reqs: Optional[List[Any]]): output_reqs = [] for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) + return output_reqs def _handle_recv_req(self, recv_req): if isinstance(recv_req, BlockReqInput): From 9de9ac0bca6cb846a7dbfdd73c70a8523b33b802 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:14:37 +0800 Subject: [PATCH 0014/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 993a973e0b8..a876bcb08ee 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -20,6 +20,7 @@ class SchedulerInputBlocker: def __init__(self): self._state = _State.UNBLOCKED + self._pending_reqs = [] def handle(self, recv_reqs: Optional[List[Any]]): output_reqs = [] @@ -38,7 +39,11 @@ def _handle_recv_req(self, recv_req): else: raise NotImplementedError(f"{recv_req=}") else: - return TODO + if self._state == _State.UNBLOCKED: + return [recv_req] + else: + self._pending_reqs.append(recv_req) + return [] def _execute_block_request(self): TODO From a9ccda96ed090e90d505adefbfea2a3985f033ba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:14:48 +0800 Subject: [PATCH 0015/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index a876bcb08ee..2a2979ccd5c 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -31,10 +31,10 @@ def handle(self, recv_reqs: Optional[List[Any]]): def _handle_recv_req(self, recv_req): if isinstance(recv_req, BlockReqInput): if recv_req.type == BlockReqType.BLOCK: - self._execute_block_request() + self._execute_block_req() return [] elif recv_req.type == BlockReqType.UNBLOCK: - self._execute_unblock_request() + self._execute_unblock_req() return [] else: raise NotImplementedError(f"{recv_req=}") @@ -45,10 +45,10 @@ def _handle_recv_req(self, recv_req): self._pending_reqs.append(recv_req) return [] - def _execute_block_request(self): + def _execute_block_req(self): TODO - def _execute_unblock_request(self): + def _execute_unblock_req(self): TODO From 955a45f91e1c7b13d1833e7aa6a99c0d6728b568 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:15:51 +0800 Subject: [PATCH 0016/1089] more --- python/sglang/srt/managers/scheduler.py | 5 ++--- python/sglang/srt/managers/scheduler_input_blocker.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 60867a422a3..1f681ef7181 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -128,6 +126,7 @@ suppress_other_loggers, ENABLE_COLOCATED_BATCH_GEN, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -372,7 +371,7 @@ def __init__( enable=server_args.enable_memory_saver ) - self.input_blocker = SchedulerInputBlocker() if ENABLE_COLOCATED_BATCH_GEN else None + self.input_blocker = SchedulerInputBlocker(noop=self.attn_tp_rank != 0) if ENABLE_COLOCATED_BATCH_GEN else None # Init profiler self.torch_profiler = None diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 2a2979ccd5c..989d177fa48 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -18,9 +18,10 @@ class SchedulerInputBlocker: - def __init__(self): + def __init__(self, noop: bool): self._state = _State.UNBLOCKED self._pending_reqs = [] + self._noop = noop def handle(self, recv_reqs: Optional[List[Any]]): output_reqs = [] From 197575e9f1c2246e15ddda3a28de706c3b84035f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:16:24 +0800 Subject: [PATCH 0017/1089] more --- .../sglang/srt/managers/scheduler_input_blocker.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 989d177fa48..b9b5b432ea8 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -24,10 +24,15 @@ def __init__(self, noop: bool): self._noop = noop def handle(self, recv_reqs: Optional[List[Any]]): - output_reqs = [] - for 
recv_req in recv_reqs: - output_reqs += self._handle_recv_req(recv_req) - return output_reqs + assert (recv_reqs is None) == self._noop + + if not self._noop: + output_reqs = [] + for recv_req in recv_reqs: + output_reqs += self._handle_recv_req(recv_req) + + if not self._noop: + return output_reqs def _handle_recv_req(self, recv_req): if isinstance(recv_req, BlockReqInput): From 2f29709968432cfe6245ac327bd8181965800725 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:17:45 +0800 Subject: [PATCH 0018/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index b9b5b432ea8..e7641c5a216 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -52,10 +52,14 @@ def _handle_recv_req(self, recv_req): return [] def _execute_block_req(self): - TODO + self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED) def _execute_unblock_req(self): - TODO + self._change_state(original=_State.BLOCKED, target=_State.AWAITING_GLOBAL_UNBLOCK) + + def _change_state(self, original: "_State", target: "_State"): + assert self._state == original, f"{self._state=} {original=} {target=}" + self._state = target class _State(Enum): From 44bce3515d1b7ba93bfe69f72fff4af79d465a7a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:18:19 +0800 Subject: [PATCH 0019/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index e7641c5a216..a1c2e827ffa 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -31,6 +31,8 @@ def handle(self, recv_reqs: Optional[List[Any]]): for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) + self._maybe_fulfill_awaiting_global_unblock() + if not self._noop: return output_reqs @@ -57,6 +59,9 @@ def _execute_block_req(self): def _execute_unblock_req(self): self._change_state(original=_State.BLOCKED, target=_State.AWAITING_GLOBAL_UNBLOCK) + def _maybe_fulfill_awaiting_global_unblock(self): + TODO + def _change_state(self, original: "_State", target: "_State"): assert self._state == original, f"{self._state=} {original=} {target=}" self._state = target From 904f6b3c1e2a0fceb2ce3ba3e181c50c4340337e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:19:16 +0800 Subject: [PATCH 0020/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index a1c2e827ffa..16dda176833 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -14,6 +14,7 @@ from enum import Enum, auto from typing import List, Optional, Any +import torch from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType @@ -60,6 +61,14 @@ def _execute_unblock_req(self): self._change_state(original=_State.BLOCKED, target=_State.AWAITING_GLOBAL_UNBLOCK) def _maybe_fulfill_awaiting_global_unblock(self): + if self._noop: + local_fulfill = True + else: + local_fulfill = TODO + + global_fulfill = 
torch.distributed.all_reduce(torch.tensor(local_fulfill), + torch.distributed.ReduceOp.MIN).item() + TODO def _change_state(self, original: "_State", target: "_State"): From 83345d797526bf17a378cf0a9aaaffd49eb7cca4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:19:41 +0800 Subject: [PATCH 0021/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 16dda176833..c3ae5b0e7dc 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -64,10 +64,10 @@ def _maybe_fulfill_awaiting_global_unblock(self): if self._noop: local_fulfill = True else: - local_fulfill = TODO + local_fulfill = self._state == _State.AWAITING_GLOBAL_UNBLOCK - global_fulfill = torch.distributed.all_reduce(torch.tensor(local_fulfill), - torch.distributed.ReduceOp.MIN).item() + global_fulfill = torch.distributed.all_reduce( + torch.tensor(local_fulfill), torch.distributed.ReduceOp.MIN).item() TODO From 69194963d286707a793e5e89b7fd32a2f688fa28 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:20:19 +0800 Subject: [PATCH 0022/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index c3ae5b0e7dc..75e4ac5c740 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -32,7 +32,7 @@ def handle(self, recv_reqs: Optional[List[Any]]): for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) - self._maybe_fulfill_awaiting_global_unblock() + self._maybe_fulfill_global_unblock_barrier() if not self._noop: return output_reqs @@ -58,13 +58,13 @@ def _execute_block_req(self): self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED) def _execute_unblock_req(self): - self._change_state(original=_State.BLOCKED, target=_State.AWAITING_GLOBAL_UNBLOCK) + self._change_state(original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER) - def _maybe_fulfill_awaiting_global_unblock(self): + def _maybe_fulfill_global_unblock_barrier(self): if self._noop: local_fulfill = True else: - local_fulfill = self._state == _State.AWAITING_GLOBAL_UNBLOCK + local_fulfill = self._state == _State.GLOBAL_UNBLOCK_BARRIER global_fulfill = torch.distributed.all_reduce( torch.tensor(local_fulfill), torch.distributed.ReduceOp.MIN).item() @@ -79,4 +79,4 @@ def _change_state(self, original: "_State", target: "_State"): class _State(Enum): UNBLOCKED = auto() BLOCKED = auto() - AWAITING_GLOBAL_UNBLOCK = auto() + GLOBAL_UNBLOCK_BARRIER = auto() From e751184ebfdcbf2702b2a7ef42a5cd3d4c7e8f7f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:20:44 +0800 Subject: [PATCH 0023/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 75e4ac5c740..aaa7a173161 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -62,14 +62,15 @@ def _execute_unblock_req(self): def 
_maybe_fulfill_global_unblock_barrier(self): if self._noop: - local_fulfill = True + local_arrived = True else: - local_fulfill = self._state == _State.GLOBAL_UNBLOCK_BARRIER + local_arrived = self._state == _State.GLOBAL_UNBLOCK_BARRIER - global_fulfill = torch.distributed.all_reduce( - torch.tensor(local_fulfill), torch.distributed.ReduceOp.MIN).item() + global_arrived = torch.distributed.all_reduce( + torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN).item() - TODO + if global_arrived: + TODO def _change_state(self, original: "_State", target: "_State"): assert self._state == original, f"{self._state=} {original=} {target=}" From 5a8a8e22972b8a07ce0f23bcbca0f2607591b911 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:21:05 +0800 Subject: [PATCH 0024/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index aaa7a173161..66eb6e75a2b 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -69,7 +69,8 @@ def _maybe_fulfill_global_unblock_barrier(self): global_arrived = torch.distributed.all_reduce( torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN).item() - if global_arrived: + if self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived: + self._change_state(original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED) TODO def _change_state(self, original: "_State", target: "_State"): From 7daf0047ce4f169a15231be8fc52488dcb0c3631 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:22:11 +0800 Subject: [PATCH 0025/1089] more --- .../sglang/srt/managers/scheduler_input_blocker.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 66eb6e75a2b..1159b0f08ba 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -32,7 +32,9 @@ def handle(self, recv_reqs: Optional[List[Any]]): for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) - self._maybe_fulfill_global_unblock_barrier() + global_arrived_unblock_barrier = self._compute_global_unblock_barrier() + if self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived_unblock_barrier: + self._handle_arrive_unblock_barrier() if not self._noop: return output_reqs @@ -60,18 +62,18 @@ def _execute_block_req(self): def _execute_unblock_req(self): self._change_state(original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER) - def _maybe_fulfill_global_unblock_barrier(self): + def _compute_global_unblock_barrier(self): if self._noop: local_arrived = True else: local_arrived = self._state == _State.GLOBAL_UNBLOCK_BARRIER - global_arrived = torch.distributed.all_reduce( + return torch.distributed.all_reduce( torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN).item() - if self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived: - self._change_state(original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED) - TODO + def _handle_arrive_unblock_barrier(self): + self._change_state(original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED) + TODO def _change_state(self, original: "_State", target: "_State"): assert self._state == original, f"{self._state=} {original=} 
{target=}" From 47cba32d35f03d0d8a1eb91b18375f7829b3ae4e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:22:43 +0800 Subject: [PATCH 0026/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 1159b0f08ba..9befeb9bb3f 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -34,7 +34,7 @@ def handle(self, recv_reqs: Optional[List[Any]]): global_arrived_unblock_barrier = self._compute_global_unblock_barrier() if self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived_unblock_barrier: - self._handle_arrive_unblock_barrier() + output_reqs += self._handle_arrive_unblock_barrier() if not self._noop: return output_reqs @@ -73,7 +73,9 @@ def _compute_global_unblock_barrier(self): def _handle_arrive_unblock_barrier(self): self._change_state(original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED) - TODO + output_reqs = [*self._pending_reqs] + self._pending_reqs.clear() + return output_reqs def _change_state(self, original: "_State", target: "_State"): assert self._state == original, f"{self._state=} {original=} {target=}" From b8ec9cc3f5afa6f9132ff67c4305efcf5b00b8f9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 18:25:12 +0800 Subject: [PATCH 0027/1089] fmt --- python/sglang/srt/managers/scheduler.py | 34 +++++++++++-------- .../srt/managers/scheduler_input_blocker.py | 19 ++++++++--- .../sglang/srt/managers/tokenizer_manager.py | 8 +++-- python/sglang/srt/utils.py | 22 ++++++------ 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1f681ef7181..460822817d7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,6 +32,8 @@ import setproctitle import torch import zmq +from torch.distributed import barrier + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -113,6 +115,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( + ENABLE_COLOCATED_BATCH_GEN, DynamicGradMode, broadcast_pyobj, configure_logger, @@ -123,10 +126,9 @@ pyspy_dump_schedulers, set_gpu_proc_affinity, set_random_seed, - suppress_other_loggers, ENABLE_COLOCATED_BATCH_GEN, + suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -356,8 +358,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -371,7 +373,11 @@ def __init__( enable=server_args.enable_memory_saver ) - self.input_blocker = SchedulerInputBlocker(noop=self.attn_tp_rank != 0) if ENABLE_COLOCATED_BATCH_GEN else None + self.input_blocker = ( + SchedulerInputBlocker(noop=self.attn_tp_rank != 0) + if ENABLE_COLOCATED_BATCH_GEN + else None + ) # 
Init profiler self.torch_profiler = None @@ -1237,10 +1243,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1265,9 +1271,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1471,8 +1477,8 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 9befeb9bb3f..0c9d04c89ad 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -12,9 +12,10 @@ # limitations under the License. # ============================================================================== from enum import Enum, auto -from typing import List, Optional, Any +from typing import Any, List, Optional import torch + from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType @@ -33,7 +34,10 @@ def handle(self, recv_reqs: Optional[List[Any]]): output_reqs += self._handle_recv_req(recv_req) global_arrived_unblock_barrier = self._compute_global_unblock_barrier() - if self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived_unblock_barrier: + if ( + self._state == _State.GLOBAL_UNBLOCK_BARRIER + and global_arrived_unblock_barrier + ): output_reqs += self._handle_arrive_unblock_barrier() if not self._noop: @@ -60,7 +64,9 @@ def _execute_block_req(self): self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED) def _execute_unblock_req(self): - self._change_state(original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER) + self._change_state( + original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER + ) def _compute_global_unblock_barrier(self): if self._noop: @@ -69,10 +75,13 @@ def _compute_global_unblock_barrier(self): local_arrived = self._state == _State.GLOBAL_UNBLOCK_BARRIER return torch.distributed.all_reduce( - torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN).item() + torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN + ).item() def _handle_arrive_unblock_barrier(self): - self._change_state(original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED) + self._change_state( + original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED + ) output_reqs = [*self._pending_reqs] self._pending_reqs.clear() return output_reqs diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7043967763c..e5e7bd981c1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import 
RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -97,10 +98,11 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( + ENABLE_COLOCATED_BATCH_GEN, dataclass_to_string_truncated, get_bool_env_var, get_zmq_socket, - kill_process_tree, ENABLE_COLOCATED_BATCH_GEN, + kill_process_tree, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback @@ -947,8 +949,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ab5233572a1..0342e6505b9 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -855,10 +855,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024 ** 3 - available_mem = mem.available / 1024 ** 3 + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024 ** 3) + buf_size = int(0.5 * 1024**3) else: buf_size = -1 @@ -1419,10 +1419,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1430,10 +1430,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: From 73eee1fc061c59864a4955541bc34fc68e66bce9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 20:15:55 +0800 Subject: [PATCH 0028/1089] more --- .../srt/managers/data_parallel_controller.py | 4 +++ python/sglang/srt/managers/scheduler.py | 4 +-- .../srt/managers/scheduler_input_blocker.py | 33 ++++++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 6 ++-- python/sglang/srt/utils.py | 5 ++- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index fb0264a6ea9..0dfdbb31805 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -25,6 +25,7 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( + BlockReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) @@ -243,6 +244,9 @@ def event_loop(self): ), ): self.dispatching(recv_req) + elif isinstance(recv_req, BlockReqInput): + for worker in self.workers: + worker.send_pyobj(recv_req) else: # Send other control messages to first worker of tp group for worker in self.workers[:: self.control_message_step]: diff --git 
a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 460822817d7..2ea453077d6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -115,11 +115,11 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( - ENABLE_COLOCATED_BATCH_GEN, DynamicGradMode, broadcast_pyobj, configure_logger, crash_on_warnings, + enable_colocated_batch_gen, get_bool_env_var, get_zmq_socket, kill_itself_when_parent_died, @@ -375,7 +375,7 @@ def __init__( self.input_blocker = ( SchedulerInputBlocker(noop=self.attn_tp_rank != 0) - if ENABLE_COLOCATED_BATCH_GEN + if enable_colocated_batch_gen() else None ) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 0c9d04c89ad..b21fda4d16d 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -16,6 +16,7 @@ import torch +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType @@ -44,6 +45,10 @@ def handle(self, recv_reqs: Optional[List[Any]]): return output_reqs def _handle_recv_req(self, recv_req): + print( + f"hi [{get_tensor_model_parallel_rank()}] handle_recv_req START {type(recv_req)=}", + flush=True, + ) if isinstance(recv_req, BlockReqInput): if recv_req.type == BlockReqType.BLOCK: self._execute_block_req() @@ -61,24 +66,36 @@ def _handle_recv_req(self, recv_req): return [] def _execute_block_req(self): + print( + f"hi [{get_tensor_model_parallel_rank()}] execute_block_req START", + flush=True, + ) self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED) def _execute_unblock_req(self): + print( + f"hi [{get_tensor_model_parallel_rank()}] execute_unblock_req START", + flush=True, + ) self._change_state( original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER ) def _compute_global_unblock_barrier(self): - if self._noop: - local_arrived = True - else: - local_arrived = self._state == _State.GLOBAL_UNBLOCK_BARRIER - - return torch.distributed.all_reduce( - torch.tensor(local_arrived), torch.distributed.ReduceOp.MIN - ).item() + local_arrived = self._noop or (self._state == _State.GLOBAL_UNBLOCK_BARRIER) + # print(f'hi [{get_tensor_model_parallel_rank()}] _compute_global_unblock_barrier START {local_arrived=}', + # flush=True) + global_arrived = torch.tensor(local_arrived).cuda() + torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) + # print(f'hi [{get_tensor_model_parallel_rank()}] _compute_global_unblock_barrier MIDDLE {global_arrived=}', + # flush=True) + return global_arrived.cpu().item() def _handle_arrive_unblock_barrier(self): + print( + f"hi [{get_tensor_model_parallel_rank()}] _handle_arrive_unblock_barrier START {[type(x) for x in self._pending_reqs]=}", + flush=True, + ) self._change_state( original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e5e7bd981c1..a1ad009d862 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -98,8 +98,8 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - 
ENABLE_COLOCATED_BATCH_GEN, dataclass_to_string_truncated, + enable_colocated_batch_gen, get_bool_env_var, get_zmq_socket, kill_process_tree, @@ -545,7 +545,7 @@ async def _handle_batch_request( rids = [] if getattr(obj, "parallel_sample_num", 1) == 1: # Send all requests - if ENABLE_COLOCATED_BATCH_GEN: + if enable_colocated_batch_gen(): self._send_block_request(BlockReqType.BLOCK) for i in range(batch_size): tmp_obj = obj[i] @@ -553,7 +553,7 @@ async def _handle_batch_request( self._send_one_request(tmp_obj, tokenized_obj, created_time) generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) - if ENABLE_COLOCATED_BATCH_GEN: + if enable_colocated_batch_gen(): self._send_block_request(BlockReqType.UNBLOCK) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0342e6505b9..e9ee7644d4a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -188,9 +188,8 @@ def clone(self) -> "DynamicGradMode": return self.__class__() -ENABLE_COLOCATED_BATCH_GEN = get_bool_env_var( - "SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false" -) +def enable_colocated_batch_gen(): + return get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false") def enable_show_time_cost(): From 55f51a2c3816f74b36e9c980f37da7087eb2ef04 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 20:16:24 +0800 Subject: [PATCH 0029/1089] more --- .../srt/managers/scheduler_input_blocker.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index b21fda4d16d..a91d199f30c 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -45,10 +45,6 @@ def handle(self, recv_reqs: Optional[List[Any]]): return output_reqs def _handle_recv_req(self, recv_req): - print( - f"hi [{get_tensor_model_parallel_rank()}] handle_recv_req START {type(recv_req)=}", - flush=True, - ) if isinstance(recv_req, BlockReqInput): if recv_req.type == BlockReqType.BLOCK: self._execute_block_req() @@ -66,36 +62,20 @@ def _handle_recv_req(self, recv_req): return [] def _execute_block_req(self): - print( - f"hi [{get_tensor_model_parallel_rank()}] execute_block_req START", - flush=True, - ) self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED) def _execute_unblock_req(self): - print( - f"hi [{get_tensor_model_parallel_rank()}] execute_unblock_req START", - flush=True, - ) self._change_state( original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER ) def _compute_global_unblock_barrier(self): local_arrived = self._noop or (self._state == _State.GLOBAL_UNBLOCK_BARRIER) - # print(f'hi [{get_tensor_model_parallel_rank()}] _compute_global_unblock_barrier START {local_arrived=}', - # flush=True) global_arrived = torch.tensor(local_arrived).cuda() torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) - # print(f'hi [{get_tensor_model_parallel_rank()}] _compute_global_unblock_barrier MIDDLE {global_arrived=}', - # flush=True) return global_arrived.cpu().item() def _handle_arrive_unblock_barrier(self): - print( - f"hi [{get_tensor_model_parallel_rank()}] _handle_arrive_unblock_barrier START {[type(x) for x in self._pending_reqs]=}", - flush=True, - ) self._change_state( original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED ) From 1e4b2e2ae17f46ae2ce38406d5097bc230aaabe2 
Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Mar 2025 20:24:16 +0800 Subject: [PATCH 0030/1089] more --- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/managers/scheduler_input_blocker.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2ea453077d6..fce8273bfce 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -374,7 +374,7 @@ def __init__( ) self.input_blocker = ( - SchedulerInputBlocker(noop=self.attn_tp_rank != 0) + SchedulerInputBlocker(server_args, noop=self.attn_tp_rank != 0) if enable_colocated_batch_gen() else None ) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index a91d199f30c..f55c6fe7372 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -16,15 +16,18 @@ import torch -from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang import ServerArgs from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType class SchedulerInputBlocker: - def __init__(self, noop: bool): + def __init__(self, server_args: ServerArgs, noop: bool): self._state = _State.UNBLOCKED self._pending_reqs = [] self._noop = noop + assert ( + server_args.disable_overlap_schedule + ), "SchedulerInputBlocker requires overlap scheduler to be disabled" def handle(self, recv_reqs: Optional[List[Any]]): assert (recv_reqs is None) == self._noop From 1dc81e6af33121c8546ec94285400939458e04be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:17:31 +0800 Subject: [PATCH 0031/1089] more --- python/sglang/srt/server_args.py | 75 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6a0166b41af..10ad064d023 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -180,6 +180,7 @@ class ServerArgs: enable_flashmla: bool = False flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None + moe_dense_tp_size: Optional[int] = None # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -378,8 +379,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -403,21 +404,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." 
- '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -430,13 +431,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -469,9 +470,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -513,8 +514,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -534,7 +535,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. 
" - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -999,7 +1000,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1012,8 +1013,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1059,6 +1060,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling DeepEP MoE implementation for EP MoE.", ) + parser.add_argument( + "--moe-dense-tp-size", + type=int, + default=ServerArgs.moe_dense_tp_size, + help="tp_size for MoE dense MLP layers", + ) # Server warmups parser.add_argument( @@ -1066,7 +1073,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 2f8a58a3f794ac1aae5a2d6d0a529e7ae57c3441 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:17:54 +0800 Subject: [PATCH 0032/1089] more --- python/sglang/srt/models/deepseek_v2.py | 50 ++++++++++++------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 4b733a67c55..b54a1884478 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -21,9 +21,6 @@ import torch import torch.nn.functional as F -from torch import nn -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -74,6 +71,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip +from torch import nn +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -373,7 +372,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -485,12 +484,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + 
self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -502,8 +501,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -544,7 +543,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -726,16 +725,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -788,11 +787,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -868,15 +867,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
@@ -933,7 +932,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1049,6 +1048,8 @@ def is_sparse_layer(l: int): hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), + tp_rank=TODO, + tp_size=TODO, ) self.is_sparse = False @@ -1236,7 +1237,6 @@ def forward_deepep( class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__( From f9b69206ecfc5659c4574387034e919ab58107e7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:18:47 +0800 Subject: [PATCH 0033/1089] more --- python/sglang/srt/managers/schedule_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 45e1d4be271..1c5f5da7204 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -78,6 +78,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "chunked_prefill_size": ServerArgs.chunked_prefill_size, + "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, } logger = logging.getLogger(__name__) From 17813196e50c8ca21dc2e8705c3a88180d2c0f7a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:19:21 +0800 Subject: [PATCH 0034/1089] more --- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1c5f5da7204..38d422d481b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -77,8 +77,8 @@ "enable_flashmla": ServerArgs.enable_flashmla, "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, - "chunked_prefill_size": ServerArgs.chunked_prefill_size, "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, + "chunked_prefill_size": ServerArgs.chunked_prefill_size, } logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8e02172772c..6221f826175 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -154,13 +153,14 @@ def __init__( "enable_flashmla": server_args.enable_flashmla, "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, + "moe_dense_tp_size": server_args.moe_dense_tp_size, "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, } ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -902,7 +902,7 
@@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From 5a751e1a91efecd18c1c5d4a335222ef511b1d8b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:20:39 +0800 Subject: [PATCH 0035/1089] more --- python/sglang/srt/models/deepseek_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b54a1884478..57e19ae8148 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1042,14 +1042,15 @@ def is_sparse_layer(l: int): ) self.is_sparse = True else: + moe_dense_tp_size = global_server_args_dict["moe_dense_tp_size"] self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), - tp_rank=TODO, - tp_size=TODO, + tp_rank=(get_tensor_model_parallel_rank() % moe_dense_tp_size) if moe_dense_tp_size else None, + tp_size=moe_dense_tp_size, ) self.is_sparse = False From cec9c35af20f4f48b5bf844d9e2522381d433e4e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:20:57 +0800 Subject: [PATCH 0036/1089] fmt --- .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/models/deepseek_v2.py | 53 ++++++++------- python/sglang/srt/server_args.py | 68 +++++++++---------- 3 files changed, 66 insertions(+), 60 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6221f826175..533e83322c6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -160,7 +161,7 @@ def __init__( ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -902,7 +903,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 57e19ae8148..43766c6e710 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -21,6 +21,9 @@ import torch import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -71,8 +74,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip -from torch import nn -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -372,7 +373,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -484,12 +485,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -501,8 +502,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -543,7 +544,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -725,16 +726,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache 
forward_batch.token_to_kv_pool.set_kv_buffer( @@ -787,11 +788,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -867,15 +868,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -932,7 +933,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1049,7 +1050,11 @@ def is_sparse_layer(l: int): hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), - tp_rank=(get_tensor_model_parallel_rank() % moe_dense_tp_size) if moe_dense_tp_size else None, + tp_rank=( + (get_tensor_model_parallel_rank() % moe_dense_tp_size) + if moe_dense_tp_size + else None + ), tp_size=moe_dense_tp_size, ) self.is_sparse = False diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 10ad064d023..f72bb5cac7f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -379,8 +379,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -404,21 +404,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." 
- '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -431,13 +431,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -470,9 +470,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -514,8 +514,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -535,7 +535,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. 
" - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1000,7 +1000,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1013,8 +1013,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1073,7 +1073,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 7ab9a1d4a1a525dd2b1eb68060e8fb84b8166eaa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 28 Mar 2025 10:22:46 +0800 Subject: [PATCH 0037/1089] more --- python/sglang/srt/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 43766c6e710..8778ec19e06 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -25,6 +25,7 @@ from transformers import PretrainedConfig from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, parallel_state, tensor_model_parallel_all_reduce, From 0a7ce786d4b1343f667256ddfbfb6eb12912e51c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:27:37 +0800 Subject: [PATCH 0038/1089] more --- python/sglang/srt/models/deepseek_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3407a2134ac..3e117458eff 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1083,15 +1083,15 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: - return self.forward_deepep( + return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) else: - return self.forward_normal( + return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) - def forward_normal( + def forward_mode_mlp_all( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1170,7 +1170,7 @@ def forward_normal( return hidden_states, residual - def forward_deepep( + def forward_mode_mlp_one( self, positions: torch.Tensor, hidden_states: torch.Tensor, 
From e5df579c89de8f33dd9b894e4d9e21d1a111991a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:28:29 +0800 Subject: [PATCH 0039/1089] more --- python/sglang/srt/server_args.py | 70 ++++++++++++++++---------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f318ed90b79..acf7fe6272b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -237,6 +237,8 @@ def __post_init__(self): self.chunked_prefill_size = 8192 assert self.chunked_prefill_size % self.page_size == 0 + + assert self.moe_dense_tp_size in {1, None}, f"moe_dense_tp_size only supports 1 and None currently" if self.enable_flashmla is True: logger.warning( @@ -380,8 +382,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -405,21 +407,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -432,13 +434,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. 
Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -471,9 +473,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -515,8 +517,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -536,7 +538,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1001,7 +1003,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1014,8 +1016,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1074,7 +1076,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 454e4ea88bfd75c1ee4f0536766bf9f280390663 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:28:54 +0800 Subject: [PATCH 0040/1089] more --- python/sglang/srt/server_args.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index acf7fe6272b..3e965813085 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -237,7 +237,7 @@ def __post_init__(self): self.chunked_prefill_size = 8192 assert self.chunked_prefill_size % self.page_size == 0 - + assert self.moe_dense_tp_size in {1, None}, f"moe_dense_tp_size only supports 1 and None currently" @@ -1154,6 +1154,10 @@ def check_server_args(self): else: self.lora_paths[lora_path] = lora_path + @property + def enable_moe_dense_fully_dp(self): + return self.moe_dense_tp_size == 1 From c6f6865df1bd00a571c3bbb27a537de6ce1b4142 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:29:14 +0800 Subject: [PATCH 0041/1089] more --- python/sglang/srt/server_args.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3e965813085..d3569d0bd12 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1154,10 +1154,6 @@ def check_server_args(self): else: self.lora_paths[lora_path] = lora_path - @property - def enable_moe_dense_fully_dp(self): - return self.moe_dense_tp_size == 1 From eda287f14d107863bf318ed6dccd852b8dec4644 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:29:42 +0800 Subject: [PATCH 0042/1089] more --- python/sglang/srt/models/deepseek_v2.py | 60 ++++++++++++------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3e117458eff..e68ed45e99d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -21,9 +21,6 @@ import torch import torch.nn.functional as F -from torch import nn -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, parallel_state, tensor_model_parallel_all_reduce, @@ -75,6 +72,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_hip +from torch import nn +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -374,7 +373,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -486,12 +485,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, 
self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -503,8 +502,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -545,7 +544,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -731,16 +730,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -793,11 +792,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -873,15 +872,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
@@ -938,7 +937,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1048,19 +1047,18 @@ def is_sparse_layer(l: int): ) self.is_sparse = True else: - moe_dense_tp_size = global_server_args_dict["moe_dense_tp_size"] + if global_server_args_dict["moe_dense_tp_size"] == 1: + mlp_tp_rank, mlp_tp_size = 0, 1 + else: + mlp_tp_rank, mlp_tp_size = None, None self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), - tp_rank=( - (get_tensor_model_parallel_rank() % moe_dense_tp_size) - if moe_dense_tp_size - else None - ), - tp_size=moe_dense_tp_size, + tp_rank=mlp_tp_rank, + tp_size=mlp_tp_size, ) self.is_sparse = False From 49e252cbd660d3f70a1e7b1baa26b0fe943db412 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:30:46 +0800 Subject: [PATCH 0043/1089] more --- python/sglang/srt/models/deepseek_v2.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e68ed45e99d..238b970cf45 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -17,6 +17,7 @@ """Inference-only DeepseekV2 model.""" import os +from enum import Enum, auto from typing import Any, Dict, Iterable, Optional, Tuple import torch @@ -969,6 +970,11 @@ def forward_absorb_fused_mla_rope( return output +class _DecoderLayerForwardMode(Enum): + MLP_ONE = auto() + MLP_ALL = auto() + + class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -1073,6 +1079,13 @@ def is_sparse_layer(l: int): config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _compute_mode(): + if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: + return _DecoderLayerForwardMode.MLP_ONE + else: + return _DecoderLayerForwardMode.MLP_ALL + def forward( self, positions: torch.Tensor, From 4c8f17955f1071259a3f25073cc230ba4b410107 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:31:35 +0800 Subject: [PATCH 0044/1089] more --- python/sglang/srt/models/deepseek_v2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 238b970cf45..4f4739fbb34 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1093,14 +1093,17 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: + mode = self._compute_mode() + if mode == _DecoderLayerForwardMode.MLP_ONE: return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) - else: + elif mode == _DecoderLayerForwardMode.MLP_ALL: return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) + else: + raise NotImplementedError def forward_mode_mlp_all( self, From 94675b444831acfddccbf73483073f61a14d49ba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:33:42 +0800 Subject: [PATCH 0045/1089] more --- python/sglang/srt/models/deepseek_v2.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py 
b/python/sglang/srt/models/deepseek_v2.py index 4f4739fbb34..fd184d6763f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -985,14 +985,6 @@ def __init__( is_nextn: bool = False, prefix: str = "", ) -> None: - - def is_sparse_layer(l: int): - return ( - config.n_routed_experts is not None - and l >= config.first_k_dense_replace - and l % config.moe_layer_freq == 0 - ) - super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -1045,7 +1037,7 @@ def is_sparse_layer(l: int): prefix=add_prefix("self_attn", prefix), ) - if is_nextn or is_sparse_layer(layer_id): + if is_nextn or self._is_sparse_layer(config, layer_id): self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, @@ -1069,7 +1061,7 @@ def is_sparse_layer(l: int): self.is_sparse = False self.input_is_scattered = ( - is_sparse_layer(layer_id - 1) + self._is_sparse_layer(config, layer_id - 1) and global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1079,6 +1071,14 @@ def is_sparse_layer(l: int): config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _is_sparse_layer(config: PretrainedConfig, layer_id: int): + return ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ) + @staticmethod def _compute_mode(): if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: From 45cea6d8c34d04374339340d1ef4efa8e77605d4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:37:14 +0800 Subject: [PATCH 0046/1089] more --- python/sglang/srt/models/deepseek_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fd184d6763f..5f814164a7f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1037,7 +1037,7 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - if is_nextn or self._is_sparse_layer(config, layer_id): + if self._is_sparse_layer(config, layer_id, is_nextn=is_nextn): self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, @@ -1061,7 +1061,7 @@ def __init__( self.is_sparse = False self.input_is_scattered = ( - self._is_sparse_layer(config, layer_id - 1) + self._is_sparse_layer(config, layer_id - 1, is_nextn=False) and global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1072,8 +1072,8 @@ def __init__( ) @staticmethod - def _is_sparse_layer(config: PretrainedConfig, layer_id: int): - return ( + def _is_sparse_layer(config: PretrainedConfig, layer_id: int, is_nextn: bool): + return is_nextn or ( config.n_routed_experts is not None and layer_id >= config.first_k_dense_replace and layer_id % config.moe_layer_freq == 0 From 2a8601e1194bda83a211c2e58ed99b1a07e13234 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:37:44 +0800 Subject: [PATCH 0047/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5f814164a7f..fdab279cc6a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1061,7 +1061,7 @@ def __init__( self.is_sparse = False self.input_is_scattered = ( - self._is_sparse_layer(config, layer_id - 1, 
is_nextn=False) + self._is_sparse_layer(config, layer_id=layer_id - 1, is_nextn=False) and global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 From 3be239941263e1a3538eca1861e7437d39e6e26b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:38:03 +0800 Subject: [PATCH 0048/1089] more --- python/sglang/srt/models/deepseek_v2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fdab279cc6a..0a500c38157 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1037,13 +1037,13 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - if self._is_sparse_layer(config, layer_id, is_nextn=is_nextn): + self.is_sparse = self._is_sparse_layer(config, layer_id, is_nextn=is_nextn) + if self.is_sparse: self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) - self.is_sparse = True else: if global_server_args_dict["moe_dense_tp_size"] == 1: mlp_tp_rank, mlp_tp_size = 0, 1 @@ -1058,7 +1058,6 @@ def __init__( tp_rank=mlp_tp_rank, tp_size=mlp_tp_size, ) - self.is_sparse = False self.input_is_scattered = ( self._is_sparse_layer(config, layer_id=layer_id - 1, is_nextn=False) From aaff5d6b6253bd0b57fbf2fb4dc6048096bb7cdf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:38:27 +0800 Subject: [PATCH 0049/1089] more --- python/sglang/srt/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0a500c38157..f2df29be35e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1037,7 +1037,7 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - self.is_sparse = self._is_sparse_layer(config, layer_id, is_nextn=is_nextn) + self.is_sparse = self._compute_is_sparse(config, layer_id, is_nextn=is_nextn) if self.is_sparse: self.mlp = DeepseekV2MoE( config=config, @@ -1060,7 +1060,7 @@ def __init__( ) self.input_is_scattered = ( - self._is_sparse_layer(config, layer_id=layer_id - 1, is_nextn=False) + self._compute_is_sparse(config, layer_id=layer_id - 1, is_nextn=False) and global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1071,7 +1071,7 @@ def __init__( ) @staticmethod - def _is_sparse_layer(config: PretrainedConfig, layer_id: int, is_nextn: bool): + def _compute_is_sparse(config: PretrainedConfig, layer_id: int, is_nextn: bool): return is_nextn or ( config.n_routed_experts is not None and layer_id >= config.first_k_dense_replace From 2032d266e43556ec731cbf713ffed4b7c02a9ed3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:38:45 +0800 Subject: [PATCH 0050/1089] more --- python/sglang/srt/models/deepseek_v2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f2df29be35e..176e16ff7a0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -970,7 +970,7 @@ def forward_absorb_fused_mla_rope( return output -class _DecoderLayerForwardMode(Enum): +class _DecoderLayerExecutionMode(Enum): MLP_ONE = auto() MLP_ALL = auto() @@ -1079,11 +1079,11 @@ def _compute_is_sparse(config: PretrainedConfig, layer_id: int, is_nextn: 
bool): ) @staticmethod - def _compute_mode(): + def _compute_execution_mode(): if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: - return _DecoderLayerForwardMode.MLP_ONE + return _DecoderLayerExecutionMode.MLP_ONE else: - return _DecoderLayerForwardMode.MLP_ALL + return _DecoderLayerExecutionMode.MLP_ALL def forward( self, @@ -1092,12 +1092,12 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - mode = self._compute_mode() - if mode == _DecoderLayerForwardMode.MLP_ONE: + mode = self._compute_execution_mode() + if mode == _DecoderLayerExecutionMode.MLP_ONE: return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) - elif mode == _DecoderLayerForwardMode.MLP_ALL: + elif mode == _DecoderLayerExecutionMode.MLP_ALL: return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) From a0e12f50315af7703a07dda08a42110d9a2201a7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:39:35 +0800 Subject: [PATCH 0051/1089] more --- python/sglang/srt/models/deepseek_v2.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 176e16ff7a0..2aee373b856 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1038,6 +1038,8 @@ def __init__( ) self.is_sparse = self._compute_is_sparse(config, layer_id, is_nextn=is_nextn) + self.execution_mode = self._compute_execution_mode(is_sparse=self.is_sparse) + if self.is_sparse: self.mlp = DeepseekV2MoE( config=config, @@ -1079,8 +1081,8 @@ def _compute_is_sparse(config: PretrainedConfig, layer_id: int, is_nextn: bool): ) @staticmethod - def _compute_execution_mode(): - if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: + def _compute_execution_mode(is_sparse: bool): + if global_server_args_dict["enable_deepep_moe"] and is_sparse: return _DecoderLayerExecutionMode.MLP_ONE else: return _DecoderLayerExecutionMode.MLP_ALL @@ -1092,12 +1094,11 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - mode = self._compute_execution_mode() - if mode == _DecoderLayerExecutionMode.MLP_ONE: + if self.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) - elif mode == _DecoderLayerExecutionMode.MLP_ALL: + elif self.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) From 2e460a497b13e802f8c0f36f568a225d2999cb4c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:41:24 +0800 Subject: [PATCH 0052/1089] more --- python/sglang/srt/models/deepseek_v2.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2aee373b856..0a0bd68c215 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -975,6 +975,12 @@ class _DecoderLayerExecutionMode(Enum): MLP_ALL = auto() +@dataclass +class _DecoderLayerInfo: + is_sparse: bool + execution_mode: _DecoderLayerExecutionMode + + class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -1073,19 +1079,18 @@ def __init__( ) @staticmethod - def _compute_is_sparse(config: PretrainedConfig, layer_id: int, is_nextn: bool): - return is_nextn or ( + def 
_compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): + is_sparse = is_nextn or ( config.n_routed_experts is not None and layer_id >= config.first_k_dense_replace and layer_id % config.moe_layer_freq == 0 ) - - @staticmethod - def _compute_execution_mode(is_sparse: bool): - if global_server_args_dict["enable_deepep_moe"] and is_sparse: - return _DecoderLayerExecutionMode.MLP_ONE - else: - return _DecoderLayerExecutionMode.MLP_ALL + execution_mode = ( + _DecoderLayerExecutionMode.MLP_ONE + if global_server_args_dict["enable_deepep_moe"] and is_sparse + else _DecoderLayerExecutionMode.MLP_ALL + ) + return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) def forward( self, From d1edc5bc319e1442a711b41dcf5eb3091f6763f5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:41:45 +0800 Subject: [PATCH 0053/1089] more --- python/sglang/srt/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0a0bd68c215..695df2d3b39 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -17,6 +17,7 @@ """Inference-only DeepseekV2 model.""" import os +from dataclasses import dataclass from enum import Enum, auto from typing import Any, Dict, Iterable, Optional, Tuple @@ -1043,10 +1044,9 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - self.is_sparse = self._compute_is_sparse(config, layer_id, is_nextn=is_nextn) - self.execution_mode = self._compute_execution_mode(is_sparse=self.is_sparse) + self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn) - if self.is_sparse: + if self.info.is_sparse: self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, From 23c1e62d0898334ceb4e5e9db347e20340389291 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:41:57 +0800 Subject: [PATCH 0054/1089] more --- python/sglang/srt/models/deepseek_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 695df2d3b39..79e42caca13 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1099,11 +1099,11 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if self.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: + if self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) - elif self.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: + elif self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) From b8a9985d2bf616665767a104d5781dc999d08b0b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:42:17 +0800 Subject: [PATCH 0055/1089] more --- python/sglang/srt/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 79e42caca13..50b51195f80 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1067,6 +1067,7 @@ def __init__( tp_size=mlp_tp_size, ) + previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) self.input_is_scattered = ( self._compute_is_sparse(config, layer_id=layer_id - 1, 
is_nextn=False) and global_server_args_dict["enable_deepep_moe"] From 2d6e43e8fc236afb56d79dbc46aad52604b2b8c4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:42:32 +0800 Subject: [PATCH 0056/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 50b51195f80..cfaebae44ed 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1069,7 +1069,7 @@ def __init__( previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) self.input_is_scattered = ( - self._compute_is_sparse(config, layer_id=layer_id - 1, is_nextn=False) + previous_layer_info.is_sparse and global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 From 323ca8aba4d6ef78baecdf40e6916c5f30fa08f2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:43:00 +0800 Subject: [PATCH 0057/1089] more --- python/sglang/srt/models/deepseek_v2.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index cfaebae44ed..7a6fdc27129 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1068,10 +1068,7 @@ def __init__( ) previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) - self.input_is_scattered = ( - previous_layer_info.is_sparse - and global_server_args_dict["enable_deepep_moe"] - ) + self.input_is_scattered = previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 4579304070a06d86f6640df2226113ca3d496d66 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:43:08 +0800 Subject: [PATCH 0058/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7a6fdc27129..9bb4e90e01f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1045,6 +1045,7 @@ def __init__( ) self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn) + previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) if self.info.is_sparse: self.mlp = DeepseekV2MoE( @@ -1067,7 +1068,6 @@ def __init__( tp_size=mlp_tp_size, ) - previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) self.input_is_scattered = previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 From ab94c7b1411ee70d9ec74412c69f21797518e2ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:44:49 +0800 Subject: [PATCH 0059/1089] more --- python/sglang/srt/models/deepseek_v2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9bb4e90e01f..060ba5fa822 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1054,7 +1054,7 @@ def __init__( prefix=add_prefix("mlp", prefix), ) else: - if global_server_args_dict["moe_dense_tp_size"] == 1: + 
if self._enable_moe_dense_fully_dp(): mlp_tp_rank, mlp_tp_size = 0, 1 else: mlp_tp_rank, mlp_tp_size = None, None @@ -1076,6 +1076,10 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + @staticmethod def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): is_sparse = is_nextn or ( @@ -1085,7 +1089,8 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): ) execution_mode = ( _DecoderLayerExecutionMode.MLP_ONE - if global_server_args_dict["enable_deepep_moe"] and is_sparse + if (global_server_args_dict["enable_deepep_moe"] and is_sparse) or \ + (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) else _DecoderLayerExecutionMode.MLP_ALL ) return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) From 9d6b1349c6cad83664a287fdb37996598352517b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:49:14 +0800 Subject: [PATCH 0060/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 060ba5fa822..1454c0a001e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -127,7 +127,7 @@ def __init__( ) self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x, forward_mode: Optional[ForwardMode] = None): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) From 8c92da45ab9834b8b90d5b7a8e4997491dc522b7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 09:58:36 +0800 Subject: [PATCH 0061/1089] more --- python/sglang/srt/models/deepseek_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 1454c0a001e..da1ffb510c2 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1248,7 +1248,9 @@ def forward_mode_mlp_one( hidden_states, residual = self.post_attention_layernorm( hidden_states, residual ) - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + + if not (self._enable_moe_dense_fully_dp() and hidden_states.shape[0] == 0): + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: hidden_states, local_hidden_states = ( From 514a6b54f286cbbb7cb9b81a82aca47bcd30f941 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 10:07:43 +0800 Subject: [PATCH 0062/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index da1ffb510c2..5336d05de1b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1249,7 +1249,7 @@ def forward_mode_mlp_one( hidden_states, residual ) - if not (self._enable_moe_dense_fully_dp() and hidden_states.shape[0] == 0): + if not (self._enable_moe_dense_fully_dp() and (not self.info.is_sparse) and hidden_states.shape[0] == 0): hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: From 1c384190bed05a4751c0b00946d674746679d950 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 30 Mar 2025 10:57:21 +0800 Subject: [PATCH 0063/1089] fmt 
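Before the pure formatting pass below, it is worth pausing on where the preceding refactor patches land: the throwaway `is_sparse_layer` closure has become a single `_compute_info` helper returning a small `_DecoderLayerInfo` record, and every consumer (MoE vs. dense MLP construction, `input_is_scattered`, the forward dispatch) reads from it. A minimal standalone sketch of that classification logic follows; a stub namespace and a plain dict stand in for sglang's `PretrainedConfig` and `global_server_args_dict`, and those two stand-ins are the only assumptions here, the rest mirrors the patches:

from dataclasses import dataclass
from enum import Enum, auto
from types import SimpleNamespace


class ExecutionMode(Enum):
    MLP_ONE = auto()  # MLP consumes scattered (per-rank) tokens
    MLP_ALL = auto()  # MLP consumes the full gathered batch


@dataclass
class LayerInfo:
    is_sparse: bool
    execution_mode: ExecutionMode


def compute_info(config, server_args, layer_id: int, is_nextn: bool) -> LayerInfo:
    # Sparse: the next-N draft layer, or a routed-experts layer past the
    # first k dense layers, at the configured MoE layer frequency.
    is_sparse = is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )
    dense_fully_dp = server_args["moe_dense_tp_size"] == 1
    execution_mode = (
        ExecutionMode.MLP_ONE
        if (server_args["enable_deepep_moe"] and is_sparse)
        or (dense_fully_dp and not is_sparse)
        else ExecutionMode.MLP_ALL
    )
    return LayerInfo(is_sparse=is_sparse, execution_mode=execution_mode)


config = SimpleNamespace(n_routed_experts=64, first_k_dense_replace=1, moe_layer_freq=1)
server_args = {"enable_deepep_moe": True, "moe_dense_tp_size": None}
for layer_id in range(3):
    print(layer_id, compute_info(config, server_args, layer_id, is_nextn=False))

Patch 0057 above then derives `input_is_scattered` directly from this record: it is true exactly when layer `layer_id - 1` resolves to `MLP_ONE`.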
--- python/sglang/srt/models/deepseek_v2.py | 65 ++++++++++++---------- python/sglang/srt/server_args.py | 73 +++++++++++++------------ 2 files changed, 75 insertions(+), 63 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5336d05de1b..b87d55f776a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -23,6 +23,9 @@ import torch import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -74,8 +77,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_hip -from torch import nn -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -375,7 +376,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -487,12 +488,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -504,8 +505,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -546,7 +547,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -732,16 +733,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -794,11 +795,11 
@@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -874,15 +875,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -939,7 +940,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1045,7 +1046,9 @@ def __init__( ) self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn) - previous_layer_info = self._compute_info(config, layer_id=layer_id - 1, is_nextn=False) + previous_layer_info = self._compute_info( + config, layer_id=layer_id - 1, is_nextn=False + ) if self.info.is_sparse: self.mlp = DeepseekV2MoE( @@ -1068,7 +1071,9 @@ def __init__( tp_size=mlp_tp_size, ) - self.input_is_scattered = previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE + self.input_is_scattered = ( + previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE + ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1089,8 +1094,8 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): ) execution_mode = ( _DecoderLayerExecutionMode.MLP_ONE - if (global_server_args_dict["enable_deepep_moe"] and is_sparse) or \ - (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + if (global_server_args_dict["enable_deepep_moe"] and is_sparse) + or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) else _DecoderLayerExecutionMode.MLP_ALL ) return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) @@ -1249,7 +1254,11 @@ def forward_mode_mlp_one( hidden_states, residual ) - if not (self._enable_moe_dense_fully_dp() and (not self.info.is_sparse) and hidden_states.shape[0] == 0): + if not ( + self._enable_moe_dense_fully_dp() + and (not self.info.is_sparse) + and hidden_states.shape[0] == 0 + ): hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d3569d0bd12..e55fbe4072c 
100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -238,7 +238,10 @@ def __post_init__(self): assert self.chunked_prefill_size % self.page_size == 0 - assert self.moe_dense_tp_size in {1, None}, f"moe_dense_tp_size only support 1 and None currently" + assert self.moe_dense_tp_size in { + 1, + None, + }, f"moe_dense_tp_size only support 1 and None currently" if self.enable_flashmla is True: logger.warning( @@ -382,8 +385,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -407,21 +410,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -434,13 +437,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. 
Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -473,9 +476,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -517,8 +520,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -538,7 +541,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1003,7 +1006,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1016,8 +1019,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1076,7 +1079,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg.
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From b9c4f3244f0bec4f6a6bce10e141184a72ba55ef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:43:14 +0800 Subject: [PATCH 0064/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 226256ed218..d3563a29ea8 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -9,8 +9,9 @@ logger = logging.getLogger(__name__) -# global expert distribution recording class ExpertDistributionRecorder: + """Global expert distribution recording""" + # This class is a singleton class def __new__(cls): if not hasattr(cls, "instance"): From c3d69c8d5f2f97b24cd5b1dfaf4092730cadaa01 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:44:57 +0800 Subject: [PATCH 0065/1089] more --- python/sglang/srt/layers/moe/topk.py | 4 +- .../srt/managers/expert_distribution.py | 2 + python/sglang/srt/managers/scheduler.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/srt/models/qwen2_moe.py | 124 +++++++++--------- 5 files changed, 65 insertions(+), 73 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 29984f3f25f..5e46e73aea3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,15 +17,13 @@ import torch import torch.nn.functional as F -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() _is_hip = is_hip() -expert_distribution_recorder = ExpertDistributionRecorder() - def fused_topk_native( hidden_states: torch.Tensor, diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index d3563a29ea8..49a0438306e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -80,3 +80,5 @@ def dump_record(self): for expert_idx, count in layer_results.items(): fd.write(f"{layer_idx},{expert_idx},{count}\n") self.reset() + +expert_distribution_recorder = ExpertDistributionRecorder() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4a434cc5ad6..40e9bfef992 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -53,7 +53,7 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -131,8 +131,6 @@ ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -expert_distribution_recorder = 
ExpertDistributionRecorder() - logger = logging.getLogger(__name__) # Test retract decode for debugging purposes diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37760407bf5..68fa39d77ed 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -66,7 +66,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -85,8 +85,6 @@ decode_attention_fwd_grouped_rope, ) -expert_distribution_recorder = ExpertDistributionRecorder() - class DeepseekV2MLP(nn.Module): def __init__( diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index fa00b35e1ac..119ac10452c 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -20,9 +20,6 @@ import torch import torch.nn.functional as F -from torch import nn -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -44,23 +41,23 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix - -expert_distribution_recorder = ExpertDistributionRecorder() +from torch import nn +from transformers import PretrainedConfig class Qwen2MoeMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -94,10 +91,10 @@ def forward(self, x): class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -147,7 +144,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: shared_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output ) # router_logits: (num_tokens, n_experts) @@ -165,17 +162,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Qwen2MoeAttention(nn.Module): def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - layer_id: int = 0, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - qkv_bias: int = True, - 
quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + qkv_bias: int = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -196,7 +193,7 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -235,10 +232,10 @@ def __init__( ) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -250,11 +247,11 @@ def forward( class Qwen2MoeDecoderLayer(nn.Module): def __init__( - self, - config: PretrainedConfig, - layer_id: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -282,7 +279,7 @@ def __init__( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen2MoeSparseMoeBlock( config=config, @@ -303,11 +300,11 @@ def __init__( ) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention if residual is None: @@ -329,10 +326,10 @@ def forward( class Qwen2MoeModel(nn.Module): def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -357,11 +354,11 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) @@ -379,14 +376,13 @@ def forward( class Qwen2MoeForCausalLM(nn.Module): - fall_back_to_pt_during_load = False def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -404,11 +400,11 @@ def __init__( 
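# A note on the recorder import this patch threads through the model files:
# the surrounding commits drop per-module `ExpertDistributionRecorder()`
# construction in favour of one shared module-level
# `expert_distribution_recorder`. Module globals are created once per
# process, so every importer observes the same instance. A minimal runnable
# sketch of the behaviour being relied on, with illustrative names rather
# than the sglang classes:
class _Recorder:
    def __init__(self):
        self.events = []

_shared = _Recorder()  # module level: created once, shared by all importers

def model_file_a():
    _shared.events.append("a")

def model_file_b():
    _shared.events.append("b")

model_file_a()
model_file_b()
assert _shared.events == ["a", "b"]  # one recorder saw both call sites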
@torch.no_grad() def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( From 0fa2355e22d754da1260c251ecbb4f9664b5259a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:45:08 +0800 Subject: [PATCH 0066/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 49a0438306e..5a529ab51a4 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,4 +1,3 @@ -import json import logging import time from collections import defaultdict @@ -9,15 +8,9 @@ logger = logging.getLogger(__name__) -class ExpertDistributionRecorder: +class _ExpertDistributionRecorder: """Global expert distribution recording""" - # This class is a singleton class - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls) - return cls.instance - def __init__(self): # the length of the dictionary is the number of layers # the length of the list is the number of tokens @@ -81,4 +74,5 @@ def dump_record(self): fd.write(f"{layer_idx},{expert_idx},{count}\n") self.reset() -expert_distribution_recorder = ExpertDistributionRecorder() + +expert_distribution_recorder = _ExpertDistributionRecorder() From a930070911473bce6bb6ba510b8c258ac8c0132c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:45:51 +0800 Subject: [PATCH 0067/1089] more --- .../srt/managers/expert_distribution.py | 10 +- python/sglang/srt/models/deepseek_v2.py | 10 +- python/sglang/srt/models/qwen2_moe.py | 122 +++++++++--------- 3 files changed, 74 insertions(+), 68 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 5a529ab51a4..22917a10bfa 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,6 +1,7 @@ import logging import time from collections import defaultdict +from contextlib import contextmanager from typing import Dict, List, Tuple import torch @@ -21,8 +22,13 @@ def __init__(self): self._record = False self._current_layer_id = "UNKNOWN" - def set_current_layer(self, layer_idx): - self._current_layer_id = layer_idx + @contextmanager + def with_current_layer(self, layer_idx): + TODO + + # TODO + # def set_current_layer(self, layer_idx): + # self._current_layer_id = layer_idx def record_new_token(self, topk_ids): if not self._record: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 68fa39d77ed..6c0b5a3d124 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1270,11 +1270,11 @@ def forward( residual = None for i in range(len(self.layers)): - expert_distribution_recorder.set_current_layer(i) - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual - ) + with expert_distribution_recorder.with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + 
positions, hidden_states, forward_batch, residual + ) if not forward_batch.forward_mode.is_idle(): if residual is None: hidden_states = self.norm(hidden_states) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 119ac10452c..57f7d111650 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -51,13 +51,13 @@ class Qwen2MoeMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -91,10 +91,10 @@ def forward(self, x): class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -144,7 +144,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: shared_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output ) # router_logits: (num_tokens, n_experts) @@ -162,17 +162,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Qwen2MoeAttention(nn.Module): def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - layer_id: int = 0, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - qkv_bias: int = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + qkv_bias: int = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -232,10 +232,10 @@ def __init__( ) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -247,11 +247,11 @@ def forward( class Qwen2MoeDecoderLayer(nn.Module): def __init__( - self, - config: PretrainedConfig, - layer_id: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -279,7 +279,7 @@ def __init__( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 ): self.mlp = 
Qwen2MoeSparseMoeBlock( config=config, @@ -300,11 +300,11 @@ def __init__( ) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention if residual is None: @@ -326,10 +326,10 @@ def forward( class Qwen2MoeModel(nn.Module): def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -354,11 +354,11 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) @@ -366,11 +366,11 @@ def forward( hidden_states = input_embeds residual = None for i in range(len(self.layers)): - expert_distribution_recorder.set_current_layer(i) - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual - ) + with expert_distribution_recorder.with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -379,10 +379,10 @@ class Qwen2MoeForCausalLM(nn.Module): fall_back_to_pt_during_load = False def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -400,11 +400,11 @@ def __init__( @torch.no_grad() def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( From 4a804ee8519a285f65af40817fafefa742a9e6d5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:46:48 +0800 Subject: [PATCH 0068/1089] more --- python/sglang/srt/utils.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 9dc2222913d..e013230b764 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -43,7 +43,7 @@ from io import BytesIO from multiprocessing.reduction import ForkingPickler from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union, TypeVar, Generic import numpy as np import psutil @@ -1834,3 +1834,25 @@ def flatten_nested_list(nested_list): ] else: return [nested_list] + +T = TypeVar('T') + + +class Withable(Generic[T]): + def __init__(self): + 
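# The `Withable` helper being added here is a small generic holder: a value
# is visible only for the duration of a `with` block, and the asserts reject
# nested or overlapping use. A condensed runnable sketch of the same idea
# (duplicated purely for illustration; the real class is the one this hunk
# introduces):
from contextlib import contextmanager

class _WithableSketch:
    def __init__(self):
        self._value = None

    @property
    def value(self):
        return self._value

    @contextmanager
    def with_value(self, new_value):
        assert self._value is None  # forbid nested / overlapping use
        self._value = new_value
        try:
            yield
        finally:
            self._value = None

current_layer = _WithableSketch()
with current_layer.with_value(3):
    assert current_layer.value == 3  # set for the duration of the block
assert current_layer.value is None  # cleared on exit, even on exceptions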
self._value: Optional[T] = None + + @property + def value(self) -> T: + return self._value + + @contextmanager + def with_value(self, new_value: T): + assert self._value is None + self._value = new_value + try: + yield + finally: + assert self._value is new_value + self._value = None + From 61583afca82bf46eb61775b47cb58739cb9fbd70 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:47:24 +0800 Subject: [PATCH 0069/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 22917a10bfa..1f15ab33972 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,10 +1,10 @@ import logging import time from collections import defaultdict -from contextlib import contextmanager from typing import Dict, List, Tuple import torch +from sglang.srt.utils import Withable logger = logging.getLogger(__name__) @@ -20,11 +20,12 @@ def __init__(self): list ) self._record = False - self._current_layer_id = "UNKNOWN" + # TODO + # self._current_layer_id = "UNKNOWN" + self._current_layer_id = Withable() - @contextmanager def with_current_layer(self, layer_idx): - TODO + return self._current_layer_id.with_value(layer_idx) # TODO # def set_current_layer(self, layer_idx): From c19a019fbb57c1b273df6ad8c380684c95d2198c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:47:53 +0800 Subject: [PATCH 0070/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1f15ab33972..61ac1880d21 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -20,31 +20,25 @@ def __init__(self): list ) self._record = False - # TODO - # self._current_layer_id = "UNKNOWN" self._current_layer_id = Withable() def with_current_layer(self, layer_idx): return self._current_layer_id.with_value(layer_idx) - # TODO - # def set_current_layer(self, layer_idx): - # self._current_layer_id = layer_idx - def record_new_token(self, topk_ids): if not self._record: return topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() for i in topk_ids_list: - self._expert_distribution_record[self._current_layer_id].append(tuple(i)) + self._expert_distribution_record[self._current_layer_id.value].append(tuple(i)) def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") self._record = False self._expert_distribution_record.clear() - self._current_layer_id = "UNKNOWN" + assert self._current_layer_id.value is None def start_record(self): """Start recording the expert distribution. 
Reset the recorder and set the recording flag to True.""" From fcfd13317e0760b9358bf92ba608383d8094226d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:48:22 +0800 Subject: [PATCH 0071/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 61ac1880d21..be3b0ba3622 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -42,7 +42,7 @@ def reset(self): def start_record(self): """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" - if self._record == True: + if self._record: logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" ) @@ -51,7 +51,7 @@ def start_record(self): def stop_record(self): """Stop recording the expert distribution. Set the recording flag to False.""" - if self._record == False: + if not self._record: logger.warning( "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" ) From 1148e8f2013669d73b8c5393990af7904b3941fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:49:12 +0800 Subject: [PATCH 0072/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index be3b0ba3622..dd25755f034 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,5 +1,6 @@ import logging import time +from abc import ABC from collections import defaultdict from typing import Dict, List, Tuple @@ -76,4 +77,12 @@ def dump_record(self): self.reset() +class _ForwardGatherer(ABC): + pass + + +class _SelectExpertsGatherer(_ForwardGatherer): + pass + + expert_distribution_recorder = _ExpertDistributionRecorder() From 5b2e87d3cc4e87b913955f8f7c29b7bceddb7ed0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:49:32 +0800 Subject: [PATCH 0073/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/srt/managers/expert_distribution.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 5e46e73aea3..247901d64f3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -253,6 +253,6 @@ def select_experts( renormalize=renormalize, ) - expert_distribution_recorder.record_new_token(topk_ids) + expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index dd25755f034..46e6553af6b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -26,7 +26,7 @@ def __init__(self): def with_current_layer(self, layer_idx): return self._current_layer_id.with_value(layer_idx) - def record_new_token(self, topk_ids): + def on_select_experts(self, topk_ids): if not self._record: return topk_ids_list = topk_ids.to("cpu", 
non_blocking=True).numpy().tolist() From d9e1bc9919d331c3ce5d84d59295b676dd377c37 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:49:49 +0800 Subject: [PATCH 0074/1089] more --- python/sglang/srt/managers/expert_distribution.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 46e6553af6b..d8ead47447c 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -29,10 +29,7 @@ def with_current_layer(self, layer_idx): def on_select_experts(self, topk_ids): if not self._record: return - topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() - torch.cuda.synchronize() - for i in topk_ids_list: - self._expert_distribution_record[self._current_layer_id.value].append(tuple(i)) + TODO def reset(self): """Reset the expert distribution recorder.""" @@ -82,7 +79,11 @@ class _ForwardGatherer(ABC): class _SelectExpertsGatherer(_ForwardGatherer): - pass + def on_select_experts(self, topk_ids): + topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() + torch.cuda.synchronize() + for i in topk_ids_list: + self._expert_distribution_record[self._current_layer_id.value].append(tuple(i)) expert_distribution_recorder = _ExpertDistributionRecorder() From 62cc09f8c609f97fe73ba2893fb9b5d62908373e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:49:57 +0800 Subject: [PATCH 0075/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index d8ead47447c..8280e079a04 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -20,40 +20,40 @@ def __init__(self): self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( list ) - self._record = False + self._recording = False self._current_layer_id = Withable() def with_current_layer(self, layer_idx): return self._current_layer_id.with_value(layer_idx) def on_select_experts(self, topk_ids): - if not self._record: + if not self._recording: return TODO def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") - self._record = False + self._recording = False self._expert_distribution_record.clear() assert self._current_layer_id.value is None def start_record(self): """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" - if self._record: + if self._recording: logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" ) self.reset() - self._record = True + self._recording = True def stop_record(self): """Stop recording the expert distribution. Set the recording flag to False.""" - if not self._record: + if not self._recording: logger.warning( "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" ) - self._record = False + self._recording = False def dump_record(self): """Dump the expert distribution record to a file. 
Reset the recorder after dumping.""" From b1ed73339a2aed482c81db94e32548de607b8abe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:50:55 +0800 Subject: [PATCH 0076/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 8280e079a04..6a4e39ada68 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -14,14 +14,17 @@ class _ExpertDistributionRecorder: """Global expert distribution recording""" def __init__(self): + self._recording = False + self._current_layer_id = Withable() + self._forward_gatherer: _ForwardGatherer = TODO + + # TODO # the length of the dictionary is the number of layers # the length of the list is the number of tokens # the length of the tuple is topk's k value self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( list ) - self._recording = False - self._current_layer_id = Withable() def with_current_layer(self, layer_idx): return self._current_layer_id.with_value(layer_idx) @@ -29,7 +32,7 @@ def with_current_layer(self, layer_idx): def on_select_experts(self, topk_ids): if not self._recording: return - TODO + self._forward_gatherer.on_select_experts(topk_ids) def reset(self): """Reset the expert distribution recorder.""" @@ -75,7 +78,8 @@ def dump_record(self): class _ForwardGatherer(ABC): - pass + def on_select_experts(self, topk_ids): + pass class _SelectExpertsGatherer(_ForwardGatherer): From e7cd1f4ade4e12d82e7bc89b83d42177c0913222 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:51:35 +0800 Subject: [PATCH 0077/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6a4e39ada68..f143016b167 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -15,7 +15,7 @@ class _ExpertDistributionRecorder: def __init__(self): self._recording = False - self._current_layer_id = Withable() + self._current_layer_idx = Withable() self._forward_gatherer: _ForwardGatherer = TODO # TODO @@ -27,19 +27,19 @@ def __init__(self): ) def with_current_layer(self, layer_idx): - return self._current_layer_id.with_value(layer_idx) + return self._current_layer_idx.with_value(layer_idx) def on_select_experts(self, topk_ids): if not self._recording: return - self._forward_gatherer.on_select_experts(topk_ids) + self._forward_gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") self._recording = False self._expert_distribution_record.clear() - assert self._current_layer_id.value is None + assert self._current_layer_idx.value is None def start_record(self): """Start recording the expert distribution. 
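The warning strings in this recorder name the three HTTP endpoints that drive it. A minimal client-side sketch of the intended lifecycle, assuming a locally running server and that the endpoints accept plain POSTs (the address and port are illustrative only):

    import requests

    base = "http://localhost:30000"  # assumed server address
    requests.post(f"{base}/start_expert_distribution_record")
    # ... send some /generate traffic so experts get selected ...
    requests.post(f"{base}/stop_expert_distribution_record")
    dump = requests.post(f"{base}/dump_expert_distribution_record")
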
Reset the recorder and set the recording flag to True.""" @@ -78,16 +78,16 @@ def dump_record(self): class _ForwardGatherer(ABC): - def on_select_experts(self, topk_ids): + def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass class _SelectExpertsGatherer(_ForwardGatherer): - def on_select_experts(self, topk_ids): + def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() for i in topk_ids_list: - self._expert_distribution_record[self._current_layer_id.value].append(tuple(i)) + self._expert_distribution_record[layer_idx].append(tuple(i)) expert_distribution_recorder = _ExpertDistributionRecorder() From f623f1a7176f590e1c2f08cb2ecb1c4b2b60519d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:53:17 +0800 Subject: [PATCH 0078/1089] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 3 +++ python/sglang/srt/managers/expert_distribution.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 6b67f6cea87..4bf33649cea 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,3 +1,5 @@ +from sglang.srt.managers.expert_distribution import expert_distribution_recorder + try: from deep_ep import Buffer @@ -196,6 +198,7 @@ def dispatch( handle, event, ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) + expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) self.tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, device=hidden_states.device, diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f143016b167..1131e728a08 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -29,11 +29,16 @@ def __init__(self): def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) - def on_select_experts(self, topk_ids): + def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return self._forward_gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) + def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): + if not self._recording: + return + self._forward_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) + def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") @@ -81,6 +86,9 @@ class _ForwardGatherer(ABC): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass + def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + pass + class _SelectExpertsGatherer(_ForwardGatherer): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): From 95d1b38b39fa846cfb446bca7b46c748471d06cf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:54:14 +0800 Subject: [PATCH 0079/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1131e728a08..349591b62e8 100644 --- 
a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -98,4 +98,10 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._expert_distribution_record[layer_idx].append(tuple(i)) +# TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready +class _DeepepNormalGatherer(_ForwardGatherer): + def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + TODO + + expert_distribution_recorder = _ExpertDistributionRecorder() From d801823ff5f3e922c9ad8146ae4a26884cf5ef42 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:54:53 +0800 Subject: [PATCH 0080/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 349591b62e8..132c423a5ac 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -89,8 +89,19 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): pass + def collect(self): + raise NotImplementedError -class _SelectExpertsGatherer(_ForwardGatherer): + +class _LayerBasedForwardGatherer(_ForwardGatherer): + def __init__(self): + self._num_recv_tokens_per_expert_list_of_layer = {} + + def collect(self): + return TODO + + +class _SelectExpertsForwardGatherer(_ForwardGatherer): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() @@ -99,7 +110,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): # TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready -class _DeepepNormalGatherer(_ForwardGatherer): +class _DeepepNormalForwardGatherer(_ForwardGatherer): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): TODO From dfa311d4d803b38b4347e06cad06395224747c36 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:55:37 +0800 Subject: [PATCH 0081/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 132c423a5ac..70d794b4b8a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -97,6 +97,11 @@ class _LayerBasedForwardGatherer(_ForwardGatherer): def __init__(self): self._num_recv_tokens_per_expert_list_of_layer = {} + def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + # TODO for TBO, we may need to relax this restriction + assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer + self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list + def collect(self): return TODO From 21dcb36f25da22352e13abc420a357a748a8e9a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:55:53 +0800 Subject: [PATCH 0082/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 
70d794b4b8a..0be9f02ce03 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -106,7 +106,7 @@ def collect(self): return TODO -class _SelectExpertsForwardGatherer(_ForwardGatherer): +class _SelectExpertsForwardGatherer(_LayerBasedForwardGatherer): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() @@ -115,9 +115,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): # TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready -class _DeepepNormalForwardGatherer(_ForwardGatherer): +class _DeepepNormalForwardGatherer(_LayerBasedForwardGatherer): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): - TODO + self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) expert_distribution_recorder = _ExpertDistributionRecorder() From f3484c8e6f35a0247dd073d861419235c741ca88 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:56:04 +0800 Subject: [PATCH 0083/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 0be9f02ce03..c0928414e25 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -112,6 +112,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): torch.cuda.synchronize() for i in topk_ids_list: self._expert_distribution_record[layer_idx].append(tuple(i)) + self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) # TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready From cfe4ec2e46839e8790d1baeff677d106ff0ecaab Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:56:21 +0800 Subject: [PATCH 0084/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c0928414e25..1d4632296ea 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -110,8 +110,11 @@ class _SelectExpertsForwardGatherer(_LayerBasedForwardGatherer): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() + + num_recv_tokens_per_expert_list = TODO for i in topk_ids_list: self._expert_distribution_record[layer_idx].append(tuple(i)) + self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From acdadc76ae49e29404319777cbe9a8197b51fee5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:56:50 +0800 Subject: [PATCH 0085/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1d4632296ea..aca6e346cbd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -113,7 +113,12 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): num_recv_tokens_per_expert_list = TODO for i in topk_ids_list: - 
self._expert_distribution_record[layer_idx].append(tuple(i)) + expert_distribution_record[layer_idx].append(tuple(i)) + for layer_idx, layer_record in expert_distribution_record.items(): + results[layer_idx] = defaultdict(int) + for token_record in layer_record: + for expert_idx in token_record: + results[layer_idx][expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 1357690bbe6abd9d43d3644d6634d3c0723aeb38 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:57:00 +0800 Subject: [PATCH 0086/1089] more --- python/sglang/srt/managers/expert_distribution.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index aca6e346cbd..8ec2bbe001e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -113,12 +113,11 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): num_recv_tokens_per_expert_list = TODO for i in topk_ids_list: - expert_distribution_record[layer_idx].append(tuple(i)) - for layer_idx, layer_record in expert_distribution_record.items(): - results[layer_idx] = defaultdict(int) - for token_record in layer_record: - for expert_idx in token_record: - results[layer_idx][expert_idx] += 1 + layer_record.append(tuple(i)) + results[layer_idx] = defaultdict(int) + for token_record in layer_record: + for expert_idx in token_record: + results[layer_idx][expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 10a9b2a89148b8a272a150c9571417ea5e4fd765 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:58:43 +0800 Subject: [PATCH 0087/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 8ec2bbe001e..2a1ff362fdf 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -107,6 +107,7 @@ def collect(self): class _SelectExpertsForwardGatherer(_LayerBasedForwardGatherer): + # pretty slow, but we will use the DeepEP Gatherer in production def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() From eb4ad457a603dc46264f9695d1c750dc940e8a74 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:59:25 +0800 Subject: [PATCH 0088/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2a1ff362fdf..e250a150445 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -112,13 +112,10 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - num_recv_tokens_per_expert_list = TODO - for i in topk_ids_list: - layer_record.append(tuple(i)) - results[layer_idx] = defaultdict(int) - for token_record in layer_record: + num_recv_tokens_per_expert_list = [0] * num_local_physical_experts + for token_record in topk_ids_list: for expert_idx in token_record: - results[layer_idx][expert_idx] += 1 + 
num_recv_tokens_per_expert_list[expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 12b75ffeeb0f0f31b8a488f533c2cd4b0d3f1d07 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 08:59:56 +0800 Subject: [PATCH 0089/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e250a150445..63367fabad2 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -123,6 +123,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): # TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready class _DeepepNormalForwardGatherer(_LayerBasedForwardGatherer): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 61fb36496df7788f703a4a01a7c31c7176739ccc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:00:25 +0800 Subject: [PATCH 0090/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 63367fabad2..c5840db0719 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -120,11 +120,15 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) -# TODO Will have a `_DeepepLowLatencyGatherer` after low-latency DeepEP is ready class _DeepepNormalForwardGatherer(_LayerBasedForwardGatherer): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) +# TODO Wait for LowLatency DeepEP merging +class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): + pass + + expert_distribution_recorder = _ExpertDistributionRecorder() From 0c2263c29eeaa5e8d7d3e1522afea04ecbeb1aa0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:00:39 +0800 Subject: [PATCH 0091/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c5840db0719..012d3d171b4 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -127,6 +127,7 @@ def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_l # TODO Wait for LowLatency DeepEP merging +# e.g. 
use naive tensor copying class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): pass From 6a0f33829096eff3b0f3914ba3985f6361f8d366 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:02:32 +0800 Subject: [PATCH 0092/1089] more --- .../srt/managers/expert_distribution.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 012d3d171b4..6db643123b1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,6 +2,7 @@ import time from abc import ABC from collections import defaultdict +from contextlib import contextmanager from typing import Dict, List, Tuple import torch @@ -29,6 +30,18 @@ def __init__(self): def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) + @contextmanager + def with_forward_pass(self): + try: + yield + finally: + self._on_forward_pass_end() + + def _on_forward_pass_end(self): + data = self._forward_gatherer.collect() + TODO_use_data + self._forward_gatherer.reset() + def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return @@ -89,6 +102,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): pass + def reset(self): + raise NotImplementedError + def collect(self): raise NotImplementedError @@ -102,6 +118,9 @@ def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[i assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list + def reset(self): + self._num_recv_tokens_per_expert_list_of_layer.clear() + def collect(self): return TODO From 86861d6d9701456d27595d935657d5260c084fd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:03:15 +0800 Subject: [PATCH 0093/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6db643123b1..66805f8b21b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -105,7 +105,7 @@ def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_l def reset(self): raise NotImplementedError - def collect(self): + def collect(self) -> torch.Tensor: raise NotImplementedError @@ -121,8 +121,12 @@ def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[i def reset(self): self._num_recv_tokens_per_expert_list_of_layer.clear() - def collect(self): - return TODO + def collect(self) -> torch.Tensor: + data = [ + TODO + for layer_index in range(num_layers) + ] + return torch.tensor(data) class _SelectExpertsForwardGatherer(_LayerBasedForwardGatherer): From c3d806a3a271b41e5b565fd2f04ffe10e6b43645 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:03:45 +0800 Subject: [PATCH 0094/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 66805f8b21b..9406ec7d758 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ 
b/python/sglang/srt/managers/expert_distribution.py @@ -116,6 +116,7 @@ def __init__(self): def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer + assert 0 <= layer_idx < num_layers self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list def reset(self): @@ -123,7 +124,7 @@ def reset(self): def collect(self) -> torch.Tensor: data = [ - TODO + self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ([0] * num_local_physical_experts) for layer_index in range(num_layers) ] return torch.tensor(data) From 99a76c9dad0cc8cfb4ef4c9461c9f29547babdd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:05:04 +0800 Subject: [PATCH 0095/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9406ec7d758..c09d0cdb49e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -52,6 +52,8 @@ def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): return self._forward_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) + # -------------------------------- TODO --------------------------------------- + def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") From 8d6f2cb1219f4943e984d7e3fadeb7dd9ad33c8b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:05:39 +0800 Subject: [PATCH 0096/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c09d0cdb49e..176ab948b47 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -158,4 +158,16 @@ class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): pass +class _Accumulator(ABC): + pass + + +class _DetailAccumulator(_Accumulator): + pass + + +class _StatAccumulator(_Accumulator): + pass + + expert_distribution_recorder = _ExpertDistributionRecorder() From 47c010c5b02d51559f23785985f37a2212779b7d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:05:49 +0800 Subject: [PATCH 0097/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 176ab948b47..4ed77b5ef41 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -18,6 +18,7 @@ def __init__(self): self._recording = False self._current_layer_idx = Withable() self._forward_gatherer: _ForwardGatherer = TODO + self._accumulator: _Accumulator = TODO # TODO # the length of the dictionary is the number of layers From 4a064a2ce0158c1594e3906143f37d471621ff09 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:06:19 +0800 Subject: [PATCH 0098/1089] more --- python/sglang/srt/managers/expert_distribution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py 
b/python/sglang/srt/managers/expert_distribution.py index 4ed77b5ef41..2e617d7140f 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -39,8 +39,8 @@ def with_forward_pass(self): self._on_forward_pass_end() def _on_forward_pass_end(self): - data = self._forward_gatherer.collect() - TODO_use_data + forward_pass_data = self._forward_gatherer.collect() + self._accumulator.append(forward_pass_data) self._forward_gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): @@ -160,15 +160,18 @@ class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): class _Accumulator(ABC): - pass + def append(self, forward_pass_data: torch.Tensor): + raise NotImplementedError class _DetailAccumulator(_Accumulator): - pass + def append(self, forward_pass_data: torch.Tensor): + TODO class _StatAccumulator(_Accumulator): - pass + def append(self, forward_pass_data: torch.Tensor): + TODO expert_distribution_recorder = _ExpertDistributionRecorder() From 634f28f79933bf1d599a71f91e49d4cbe3c14651 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:07:07 +0800 Subject: [PATCH 0099/1089] more --- python/sglang/srt/model_executor/model_runner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f5405c9afdc..bb3a679295b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -45,6 +45,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -971,6 +972,10 @@ def forward_idle(self, forward_batch: ForwardBatch): def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: + with expert_distribution_recorder.with_forward_pass(): + return self._forward_raw(forward_batch, skip_attn_backend_init) + + def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: if ( forward_batch.forward_mode.is_cuda_graph() and self.cuda_graph_runner From a495780d0b81d62b0577abbf8c3f9df5dff0538b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:08:11 +0800 Subject: [PATCH 0100/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2e617d7140f..9f802d31bdc 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,5 +1,4 @@ import logging -import time from abc import ABC from collections import defaultdict from contextlib import contextmanager @@ -87,14 +86,6 @@ def dump_record(self): for token_record in layer_record: for expert_idx in token_record: results[layer_idx][expert_idx] += 1 - with open( - f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv", - "w", - ) as fd: - fd.write("layer_id,expert_id,count\n") - for layer_idx, layer_results in results.items(): - for expert_idx, count in layer_results.items(): - fd.write(f"{layer_idx},{expert_idx},{count}\n") 
self.reset() From 03ae2eda03a3ad6c8b5269db45c2ce722088c4c5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:09:04 +0800 Subject: [PATCH 0101/1089] more --- .../srt/managers/expert_distribution.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9f802d31bdc..05487535b13 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,7 +2,7 @@ from abc import ABC from collections import defaultdict from contextlib import contextmanager -from typing import Dict, List, Tuple +from typing import List import torch from sglang.srt.utils import Withable @@ -19,14 +19,6 @@ def __init__(self): self._forward_gatherer: _ForwardGatherer = TODO self._accumulator: _Accumulator = TODO - # TODO - # the length of the dictionary is the number of layers - # the length of the list is the number of tokens - # the length of the tuple is topk's k value - self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( - list - ) - def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) @@ -58,7 +50,8 @@ def reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting expert distribution record...") self._recording = False - self._expert_distribution_record.clear() + self._forward_gatherer.reset() + self._accumulator.reset() assert self._current_layer_idx.value is None def start_record(self): @@ -154,15 +147,24 @@ class _Accumulator(ABC): def append(self, forward_pass_data: torch.Tensor): raise NotImplementedError + def reset(self): + raise NotImplementedError + class _DetailAccumulator(_Accumulator): def append(self, forward_pass_data: torch.Tensor): TODO + def reset(self): + TODO + class _StatAccumulator(_Accumulator): def append(self, forward_pass_data: torch.Tensor): TODO + def reset(self): + TODO + expert_distribution_recorder = _ExpertDistributionRecorder() From 7ece402b1852f8415e831977379e7fcb0fcd462e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:09:23 +0800 Subject: [PATCH 0102/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 05487535b13..ae7e0ae1c82 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -48,11 +48,11 @@ def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): def reset(self): """Reset the expert distribution recorder.""" - logger.info("Resetting expert distribution record...") + logger.info("Resetting ExpertDistributionRecorder...") self._recording = False + assert self._current_layer_idx.value is None self._forward_gatherer.reset() self._accumulator.reset() - assert self._current_layer_idx.value is None def start_record(self): """Start recording the expert distribution. 
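Patch 0100 above deletes the per-rank CSV writer; later patches return the record to the caller instead of writing files. For anyone relying on the old `layer_id,expert_id,count` format, a sketch of rebuilding it from a stat-style dump, assuming `dump["physical_count"]` is a nested list shaped (num_layers, num_local_physical_experts):

    def to_csv(dump: dict) -> str:
        # One row per (layer, expert) pair, mirroring the removed fd.write() loop.
        rows = ["layer_id,expert_id,count"]
        for layer_idx, layer_counts in enumerate(dump["physical_count"]):
            for expert_idx, count in enumerate(layer_counts):
                rows.append(f"{layer_idx},{expert_idx},{count}")
        return "\n".join(rows)
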
Reset the recorder and set the recording flag to True.""" From 057d74067df5b2666972d9b2b086ca33648af149 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:09:57 +0800 Subject: [PATCH 0103/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ae7e0ae1c82..15717ce93d9 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,6 +1,5 @@ import logging from abc import ABC -from collections import defaultdict from contextlib import contextmanager from typing import List @@ -55,7 +54,7 @@ def reset(self): self._accumulator.reset() def start_record(self): - """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" + """Start recording the expert distribution.""" if self._recording: logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" @@ -64,7 +63,7 @@ def start_record(self): self._recording = True def stop_record(self): - """Stop recording the expert distribution. Set the recording flag to False.""" + """Stop recording the expert distribution.""" if not self._recording: logger.warning( "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" @@ -72,14 +71,9 @@ def stop_record(self): self._recording = False def dump_record(self): - """Dump the expert distribution record to a file. Reset the recorder after dumping.""" - results = {} - for layer_idx, layer_record in self._expert_distribution_record.items(): - results[layer_idx] = defaultdict(int) - for token_record in layer_record: - for expert_idx in token_record: - results[layer_idx][expert_idx] += 1 + """Dump the expert distribution record and reset the recorder after dumping.""" self.reset() + return TODO class _ForwardGatherer(ABC): From e97bd895fe710215c37cd8371741c72b40774bc8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:10:10 +0800 Subject: [PATCH 0104/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 15717ce93d9..2b06da0f4e5 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -43,9 +43,7 @@ def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): return self._forward_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) - # -------------------------------- TODO --------------------------------------- - - def reset(self): + def _reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") self._recording = False @@ -59,7 +57,7 @@ def start_record(self): logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" 
) - self.reset() + self._reset() self._recording = True def stop_record(self): @@ -72,7 +70,7 @@ def stop_record(self): def dump_record(self): """Dump the expert distribution record and reset the recorder after dumping.""" - self.reset() + self._reset() return TODO From c0d97d6916d0b3c898930f02e2191f3f6d71a9ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:10:36 +0800 Subject: [PATCH 0105/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2b06da0f4e5..f41689940a1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -70,8 +70,9 @@ def stop_record(self): def dump_record(self): """Dump the expert distribution record and reset the recorder after dumping.""" + output = self._accumulator.dump() self._reset() - return TODO + return output class _ForwardGatherer(ABC): @@ -142,6 +143,9 @@ def append(self, forward_pass_data: torch.Tensor): def reset(self): raise NotImplementedError + def dump(self): + raise NotImplementedError + class _DetailAccumulator(_Accumulator): def append(self, forward_pass_data: torch.Tensor): @@ -150,6 +154,9 @@ def append(self, forward_pass_data: torch.Tensor): def reset(self): TODO + def dump(self): + TODO + class _StatAccumulator(_Accumulator): def append(self, forward_pass_data: torch.Tensor): @@ -158,5 +165,8 @@ def append(self, forward_pass_data: torch.Tensor): def reset(self): TODO + def dump(self): + TODO + expert_distribution_recorder = _ExpertDistributionRecorder() From bd03337a4e3a02f3ef2161b7019f6c0d4cdf7bed Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:11:06 +0800 Subject: [PATCH 0106/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f41689940a1..af334a45039 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -9,6 +9,8 @@ logger = logging.getLogger(__name__) +# --------------------------------------- Entrypoint ----------------------------------------- + class _ExpertDistributionRecorder: """Global expert distribution recording""" @@ -75,6 +77,8 @@ def dump_record(self): return output +# --------------------------------------- ForwardGatherer ----------------------------------------- + class _ForwardGatherer(ABC): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass @@ -136,6 +140,8 @@ class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): pass +# --------------------------------------- Accumulator ----------------------------------------- + class _Accumulator(ABC): def append(self, forward_pass_data: torch.Tensor): raise NotImplementedError From de8f68e9c8c17853ca97bd6cad970928e7d05007 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:11:14 +0800 Subject: [PATCH 0107/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index af334a45039..cf28dae0de0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -77,6 +77,9 @@ def 
dump_record(self): return output +expert_distribution_recorder = _ExpertDistributionRecorder() + + # --------------------------------------- ForwardGatherer ----------------------------------------- class _ForwardGatherer(ABC): @@ -173,6 +176,3 @@ def reset(self): def dump(self): TODO - - -expert_distribution_recorder = _ExpertDistributionRecorder() From a81db0df6a056676e1c663f34bddaa53474a86c0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:12:36 +0800 Subject: [PATCH 0108/1089] more --- .../sglang/srt/managers/expert_distribution.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cf28dae0de0..cc4d26db343 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,7 +1,7 @@ import logging from abc import ABC from contextlib import contextmanager -from typing import List +from typing import List, Type import torch from sglang.srt.utils import Withable @@ -17,8 +17,8 @@ class _ExpertDistributionRecorder: def __init__(self): self._recording = False self._current_layer_idx = Withable() - self._forward_gatherer: _ForwardGatherer = TODO - self._accumulator: _Accumulator = TODO + self._forward_gatherer = _ForwardGatherer.init_new() + self._accumulator = _Accumulator.init_new() def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) @@ -83,6 +83,10 @@ def dump_record(self): # --------------------------------------- ForwardGatherer ----------------------------------------- class _ForwardGatherer(ABC): + @staticmethod + def init_new() -> "_ForwardGatherer": + return TODO + def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass @@ -146,6 +150,14 @@ class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): # --------------------------------------- Accumulator ----------------------------------------- class _Accumulator(ABC): + @staticmethod + def init_new() -> "_Accumulator": + return _Accumulator.get_class()() + + @staticmethod + def get_class() -> Type["_Accumulator"]: + return TODO + def append(self, forward_pass_data: torch.Tensor): raise NotImplementedError From 2b7cc463749f2f0e1c5920c9dd93a15d08ea5683 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:13:32 +0800 Subject: [PATCH 0109/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cc4d26db343..8a9812bd730 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,6 +1,7 @@ import logging from abc import ABC from contextlib import contextmanager +from copy import deepcopy from typing import List, Type import torch @@ -169,14 +170,19 @@ def dump(self): class _DetailAccumulator(_Accumulator): + def __init__(self): + self._records = [] + def append(self, forward_pass_data: torch.Tensor): - TODO + self._records.append(dict( + forward_pass_data=forward_pass_data.tolist(), + )) def reset(self): - TODO + self._records.clear() def dump(self): - TODO + return deepcopy(self._records) class _StatAccumulator(_Accumulator): From cbeae3ac62ec343d6b21a36681faf4c3b9806be9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:14:17 +0800 Subject: [PATCH 0110/1089] more --- 
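Patch 0108 adds the `init_new`/`get_class` factories and patch 0109 fills in `_DetailAccumulator`; the `_StatAccumulator` counterpart is fleshed out over the next few patches. A toy sketch of the tradeoff between the two flavors, with shapes assumed from the surrounding patches (layers x local physical experts); detail keeps one entry per forward pass, stat keeps only a running sum:

    import torch

    num_layers, num_local_physical_experts = 2, 4  # illustrative sizes only

    detail_records = []  # _DetailAccumulator style: grows by one entry per pass
    stat_count = torch.zeros((num_layers, num_local_physical_experts))

    single_pass = torch.randint(0, 8, (num_layers, num_local_physical_experts))
    detail_records.append({"physical_count": single_pass.tolist()})
    stat_count += single_pass  # _StatAccumulator style: fixed memory, no per-pass detail
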
python/sglang/srt/managers/expert_distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 8a9812bd730..2bc1dd906f7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -186,6 +186,9 @@ def dump(self): class _StatAccumulator(_Accumulator): + def __init__(self): + self._TODO = TODO + def append(self, forward_pass_data: torch.Tensor): TODO From ab7eeef984e5d6ae5cd800d78cc1ba5f0f02ad4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:15:14 +0800 Subject: [PATCH 0111/1089] more --- .../srt/managers/expert_distribution.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2bc1dd906f7..86041127298 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -18,7 +18,7 @@ class _ExpertDistributionRecorder: def __init__(self): self._recording = False self._current_layer_idx = Withable() - self._forward_gatherer = _ForwardGatherer.init_new() + self._single_pass_gatherer = _SinglePassGatherer.init_new() self._accumulator = _Accumulator.init_new() def with_current_layer(self, layer_idx): @@ -32,26 +32,27 @@ def with_forward_pass(self): self._on_forward_pass_end() def _on_forward_pass_end(self): - forward_pass_data = self._forward_gatherer.collect() + forward_pass_data = self._single_pass_gatherer.collect() self._accumulator.append(forward_pass_data) - self._forward_gatherer.reset() + self._single_pass_gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return - self._forward_gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) + self._single_pass_gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): if not self._recording: return - self._forward_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) + self._single_pass_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, + num_recv_tokens_per_expert_list) def _reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") self._recording = False assert self._current_layer_idx.value is None - self._forward_gatherer.reset() + self._single_pass_gatherer.reset() self._accumulator.reset() def start_record(self): @@ -81,11 +82,11 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() -# --------------------------------------- ForwardGatherer ----------------------------------------- +# --------------------------------------- SinglePassGatherer ----------------------------------------- -class _ForwardGatherer(ABC): +class _SinglePassGatherer(ABC): @staticmethod - def init_new() -> "_ForwardGatherer": + def init_new() -> "_SinglePassGatherer": return TODO def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): @@ -101,7 +102,7 @@ def collect(self) -> torch.Tensor: raise NotImplementedError -class _LayerBasedForwardGatherer(_ForwardGatherer): +class _LayerBasedSinglePassGatherer(_SinglePassGatherer): def __init__(self): self._num_recv_tokens_per_expert_list_of_layer = {} @@ -122,7 +123,7 @@ def collect(self) -> 
torch.Tensor: return torch.tensor(data) -class _SelectExpertsForwardGatherer(_LayerBasedForwardGatherer): +class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer): # pretty slow, but we will use the DeepEP Gatherer in production def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() @@ -136,7 +137,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) -class _DeepepNormalForwardGatherer(_LayerBasedForwardGatherer): +class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) @@ -144,7 +145,7 @@ def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_l # TODO Wait for LowLatency DeepEP merging # e.g. use naive tensor copying -class _DeepepLowLatencyForwardGatherer(_ForwardGatherer): +class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): pass From 3b7b887a323a246b612332b6f7d35bc835909836 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:16:04 +0800 Subject: [PATCH 0112/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 86041127298..5b84bb71e39 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -32,8 +32,8 @@ def with_forward_pass(self): self._on_forward_pass_end() def _on_forward_pass_end(self): - forward_pass_data = self._single_pass_gatherer.collect() - self._accumulator.append(forward_pass_data) + single_pass_count = self._single_pass_gatherer.collect() + self._accumulator.append(single_pass_count) self._single_pass_gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): @@ -160,7 +160,7 @@ def init_new() -> "_Accumulator": def get_class() -> Type["_Accumulator"]: return TODO - def append(self, forward_pass_data: torch.Tensor): + def append(self, single_pass_count: torch.Tensor): raise NotImplementedError def reset(self): @@ -174,9 +174,9 @@ class _DetailAccumulator(_Accumulator): def __init__(self): self._records = [] - def append(self, forward_pass_data: torch.Tensor): + def append(self, single_pass_count: torch.Tensor): self._records.append(dict( - forward_pass_data=forward_pass_data.tolist(), + single_pass_count=single_pass_count.tolist(), )) def reset(self): @@ -190,7 +190,7 @@ class _StatAccumulator(_Accumulator): def __init__(self): self._TODO = TODO - def append(self, forward_pass_data: torch.Tensor): + def append(self, single_pass_count: torch.Tensor): TODO def reset(self): From 56f5e093ef2735ac81c5a373f79183ab938d59c6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:16:38 +0800 Subject: [PATCH 0113/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 5b84bb71e39..f5c24b6bdc6 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -32,8 +32,8 @@ def with_forward_pass(self): 
self._on_forward_pass_end() def _on_forward_pass_end(self): - single_pass_count = self._single_pass_gatherer.collect() - self._accumulator.append(single_pass_count) + single_pass_physical_count = self._single_pass_gatherer.collect() + self._accumulator.append(single_pass_physical_count) self._single_pass_gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): @@ -160,7 +160,7 @@ def init_new() -> "_Accumulator": def get_class() -> Type["_Accumulator"]: return TODO - def append(self, single_pass_count: torch.Tensor): + def append(self, single_pass_physical_count: torch.Tensor): raise NotImplementedError def reset(self): @@ -174,9 +174,9 @@ class _DetailAccumulator(_Accumulator): def __init__(self): self._records = [] - def append(self, single_pass_count: torch.Tensor): + def append(self, single_pass_physical_count: torch.Tensor): self._records.append(dict( - single_pass_count=single_pass_count.tolist(), + physical_count=single_pass_physical_count.tolist(), )) def reset(self): @@ -188,9 +188,9 @@ def dump(self): class _StatAccumulator(_Accumulator): def __init__(self): - self._TODO = TODO + self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) - def append(self, single_pass_count: torch.Tensor): + def append(self, single_pass_physical_count: torch.Tensor): TODO def reset(self): From d1716ae5582cd1286ae3732c6b7400ec880be22b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:16:50 +0800 Subject: [PATCH 0114/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f5c24b6bdc6..95c04c463f1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -191,7 +191,7 @@ def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) def append(self, single_pass_physical_count: torch.Tensor): - TODO + self._physical_count += single_pass_physical_count def reset(self): TODO From 942b8e2c7ab208dd497c2ce0b16580319ead31dd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:17:18 +0800 Subject: [PATCH 0115/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 95c04c463f1..441b96459e7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -194,7 +194,9 @@ def append(self, single_pass_physical_count: torch.Tensor): self._physical_count += single_pass_physical_count def reset(self): - TODO + self._physical_count[...] 
= 0 def dump(self): - TODO + return dict( + physical_count=self._physical_count.tolist(), + ) From 165869650c566d32d262601fbd5d1f28e1381ee1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:18:14 +0800 Subject: [PATCH 0116/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 441b96459e7..10927efccb9 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -5,6 +5,7 @@ from typing import List, Type import torch +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable logger = logging.getLogger(__name__) @@ -86,8 +87,11 @@ def dump_record(self): class _SinglePassGatherer(ABC): @staticmethod - def init_new() -> "_SinglePassGatherer": - return TODO + def init_new(server_args: ServerArgs) -> "_SinglePassGatherer": + if server_args.enable_deepep_moe: + # TODO DeepEP low latency + return _DeepepNormalSinglePassGatherer() + return _LayerBasedSinglePassGatherer() def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass From 6b595d37712a316a6eb28105d145fe8011e1a294 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:18:23 +0800 Subject: [PATCH 0117/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 10927efccb9..3d103ddd60b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -19,7 +19,7 @@ class _ExpertDistributionRecorder: def __init__(self): self._recording = False self._current_layer_idx = Withable() - self._single_pass_gatherer = _SinglePassGatherer.init_new() + self._single_pass_gatherer = _SinglePassGatherer.init_new(server_args) self._accumulator = _Accumulator.init_new() def with_current_layer(self, layer_idx): From b84e2ca090801d9265ae2a483d7be252cc908f8b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:18:44 +0800 Subject: [PATCH 0118/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 3d103ddd60b..415c2d8fa37 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -162,7 +162,9 @@ def init_new() -> "_Accumulator": @staticmethod def get_class() -> Type["_Accumulator"]: - return TODO + if TODO: + return _DetailAccumulator + return _StatAccumulator def append(self, single_pass_physical_count: torch.Tensor): raise NotImplementedError From 29b8f4a5a61f6c9e9c5ab809bffc7dd2667b8210 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:19:08 +0800 Subject: [PATCH 0119/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 415c2d8fa37..e76e7ec2879 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -6,7 +6,7 @@ import torch from sglang.srt.server_args import ServerArgs -from sglang.srt.utils 
import Withable +from sglang.srt.utils import Withable, get_bool_env_var logger = logging.getLogger(__name__) @@ -162,7 +162,7 @@ def init_new() -> "_Accumulator": @staticmethod def get_class() -> Type["_Accumulator"]: - if TODO: + if get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"): return _DetailAccumulator return _StatAccumulator From 7a5f5444aecaaeab5b3f896d2ed71420aed52a9f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:20:01 +0800 Subject: [PATCH 0120/1089] more --- python/sglang/srt/managers/io_struct.py | 2 +- python/sglang/srt/managers/scheduler.py | 30 ++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e25a8f242c5..f51bdc72835 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -666,7 +666,7 @@ class ExpertDistributionReq(Enum): @dataclass class ExpertDistributionReqOutput: - pass + dump_output: Optional[Any] = None @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 40e9bfef992..a49668c063d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -130,6 +128,7 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -359,8 +358,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1236,10 +1235,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1264,9 +1263,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1470,8 +1469,8 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # We should have at least 1 token for sample in every case. 
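A quick aside on the SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL switch introduced above: the accumulator flavor is chosen once per process from an environment flag. A minimal, self-contained sketch of this selection pattern, assuming a simplified get_bool_env_var and toy classes standing in for _DetailAccumulator and _StatAccumulator:

import os
from typing import Type


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var (assumption).
    return os.getenv(name, default).lower() in ("true", "1")


class Accumulator:
    @staticmethod
    def get_class() -> Type["Accumulator"]:
        # One env check picks the concrete strategy for the whole process.
        if get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"):
            return DetailAccumulator
        return StatAccumulator


class DetailAccumulator(Accumulator):  # stands in for _DetailAccumulator
    pass


class StatAccumulator(Accumulator):  # stands in for _StatAccumulator
    pass


assert Accumulator.get_class() is StatAccumulator  # detail mode defaults to off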
max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) @@ -1908,15 +1907,16 @@ def stop_profile(self) -> None: ) def expert_distribution_handle(self, recv_req: ExpertDistributionReq): + dump_output = None if recv_req == ExpertDistributionReq.START_RECORD: expert_distribution_recorder.start_record() elif recv_req == ExpertDistributionReq.STOP_RECORD: expert_distribution_recorder.stop_record() elif recv_req == ExpertDistributionReq.DUMP_RECORD: - expert_distribution_recorder.dump_record() + dump_output = expert_distribution_recorder.dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") - return ExpertDistributionReqOutput() + return ExpertDistributionReqOutput(dump_output=dump_output) def open_session(self, recv_req: OpenSessionReqInput): # handle error From de3d02a5ba79f1b6f48042022bb06d207f0d1e58 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:20:29 +0800 Subject: [PATCH 0121/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1b6ad816f2b..76eb4987776 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -654,6 +654,7 @@ async def stop_expert_distribution_record(self): async def dump_expert_distribution_record(self): await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) + return TODO async def update_weights_from_disk( self, From f806ac8aec8e3cc19657371092827fb0c1fd0592 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:21:08 +0800 Subject: [PATCH 0122/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 76eb4987776..a9702aa4f9c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -653,7 +652,9 @@ async def stop_expert_distribution_record(self): await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) async def dump_expert_distribution_record(self): - await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) + raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( + ExpertDistributionReq.DUMP_RECORD) + raw_outputs = [output.dump_output for output in raw_outputs] return TODO async def update_weights_from_disk( @@ -958,8 +959,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From b18639c4df9b0a6fee91258cf08c15b700faf84e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:22:14 +0800 Subject: [PATCH 0123/1089] more --- 
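Patch 0120 turns the expert-distribution dump into a request/response round trip: the scheduler maps the ExpertDistributionReq enum onto recorder calls, and only DUMP_RECORD carries a payload back in ExpertDistributionReqOutput. A hedged sketch of that flow, with a stub recorder in place of the real one (the stub and its return value are assumptions for illustration):

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional


class ExpertDistributionReq(Enum):
    START_RECORD = auto()
    STOP_RECORD = auto()
    DUMP_RECORD = auto()


@dataclass
class ExpertDistributionReqOutput:
    dump_output: Optional[Any] = None


class _StubRecorder:
    # Stand-in for the real expert distribution recorder (assumption).
    def start_record(self): ...
    def stop_record(self): ...
    def dump_record(self):
        return {"physical_count": [[0]]}


def handle(recorder: _StubRecorder, recv_req: ExpertDistributionReq) -> ExpertDistributionReqOutput:
    dump_output = None
    if recv_req == ExpertDistributionReq.START_RECORD:
        recorder.start_record()
    elif recv_req == ExpertDistributionReq.STOP_RECORD:
        recorder.stop_record()
    elif recv_req == ExpertDistributionReq.DUMP_RECORD:
        # Only the dump request produces a payload for the caller.
        dump_output = recorder.dump_record()
    else:
        raise ValueError("Unrecognized ExpertDistributionReq value")
    return ExpertDistributionReqOutput(dump_output=dump_output)


assert handle(_StubRecorder(), ExpertDistributionReq.DUMP_RECORD).dump_output is not None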
python/sglang/srt/managers/expert_distribution.py | 6 +++++- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e76e7ec2879..faed84db427 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,7 +2,7 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy -from typing import List, Type +from typing import List, Type, Any import torch from sglang.srt.server_args import ServerArgs @@ -83,6 +83,10 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() +def convert_dump_to_logical(physical_dumps: List[Any]): + return TODO + + # --------------------------------------- SinglePassGatherer ----------------------------------------- class _SinglePassGatherer(ABC): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a9702aa4f9c..9531cc038ea 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -50,6 +50,7 @@ from sglang.srt.disaggregation.conn import KVBootstrapServer from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.managers import expert_distribution from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -654,8 +655,7 @@ async def stop_expert_distribution_record(self): async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) - raw_outputs = [output.dump_output for output in raw_outputs] - return TODO + return expert_distribution.convert_dump_to_logical([output.dump_output for output in raw_outputs]) async def update_weights_from_disk( self, From c712bbdd15822ccf433c0974ea61f44fc05bd643 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:22:52 +0800 Subject: [PATCH 0124/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 ++++++++++++--- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index faed84db427..c381e434fc1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -83,8 +83,8 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() -def convert_dump_to_logical(physical_dumps: List[Any]): - return TODO +def convert_dumps_to_logical(physical_dumps: List[Any]): + return _Accumulator.get_class().convert_dumps_to_logical(physical_dumps) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -170,6 +170,9 @@ def get_class() -> Type["_Accumulator"]: return _DetailAccumulator return _StatAccumulator + @classmethod + def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + raise NotImplementedError def append(self, single_pass_physical_count: torch.Tensor): raise NotImplementedError @@ -181,9 +184,12 @@ def dump(self): class _DetailAccumulator(_Accumulator): + @classmethod + def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + TODO + def __init__(self): self._records = [] - 
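The two accumulator subclasses being filled in here trade memory for detail: _DetailAccumulator keeps one record per forward pass (useful for debugging), while _StatAccumulator folds every pass into a running elementwise sum. A runnable sketch of the contrast, with a toy shape in place of the model-derived num_layers and num_local_physical_experts (both sizes are assumptions):

from copy import deepcopy

import torch

NUM_LAYERS, NUM_LOCAL_PHYSICAL_EXPERTS = 2, 4  # toy sizes (assumption)


class DetailAccumulator:
    def __init__(self):
        self._records = []

    def append(self, single_pass_physical_count: torch.Tensor):
        # One entry per forward pass: memory grows with recording length.
        self._records.append(dict(physical_count=single_pass_physical_count.tolist()))

    def dump(self):
        return deepcopy(self._records)


class StatAccumulator:
    def __init__(self):
        self._physical_count = torch.zeros((NUM_LAYERS, NUM_LOCAL_PHYSICAL_EXPERTS))

    def append(self, single_pass_physical_count: torch.Tensor):
        # Running sum: constant memory regardless of how long recording runs.
        self._physical_count += single_pass_physical_count

    def dump(self):
        return dict(physical_count=self._physical_count.tolist())


detail, stat = DetailAccumulator(), StatAccumulator()
for _ in range(3):
    single_pass = torch.ones(NUM_LAYERS, NUM_LOCAL_PHYSICAL_EXPERTS)
    detail.append(single_pass)
    stat.append(single_pass)
assert len(detail.dump()) == 3 and stat.dump()["physical_count"][0][0] == 3.0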
def append(self, single_pass_physical_count: torch.Tensor): self._records.append(dict( physical_count=single_pass_physical_count.tolist(), @@ -197,6 +203,9 @@ def dump(self): class _StatAccumulator(_Accumulator): + @classmethod + def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + TODO def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9531cc038ea..d76bdd3c23f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -655,7 +655,7 @@ async def stop_expert_distribution_record(self): async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) - return expert_distribution.convert_dump_to_logical([output.dump_output for output in raw_outputs]) + return expert_distribution.convert_dumps_to_logical([output.dump_output for output in raw_outputs]) async def update_weights_from_disk( self, From 05305b2ea2054f81ccf3b5c10f869e1f90a27062 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:22:59 +0800 Subject: [PATCH 0125/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c381e434fc1..1b9a3441a5d 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -173,6 +173,7 @@ def get_class() -> Type["_Accumulator"]: @classmethod def convert_dumps_to_logical(cls, physical_dumps: List[Any]): raise NotImplementedError + def append(self, single_pass_physical_count: torch.Tensor): raise NotImplementedError @@ -190,6 +191,7 @@ def convert_dumps_to_logical(cls, physical_dumps: List[Any]): def __init__(self): self._records = [] + def append(self, single_pass_physical_count: torch.Tensor): self._records.append(dict( physical_count=single_pass_physical_count.tolist(), @@ -206,6 +208,7 @@ class _StatAccumulator(_Accumulator): @classmethod def convert_dumps_to_logical(cls, physical_dumps: List[Any]): TODO + def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) From 14fdd55710c92de882359b230d7c8492c44527d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:23:49 +0800 Subject: [PATCH 0126/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1b9a3441a5d..8540c1cbdfd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -187,7 +187,12 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod def convert_dumps_to_logical(cls, physical_dumps: List[Any]): - TODO + # Do not convert to logical since we want all details + return [ + record + for physical_dump in physical_dumps + for record in physical_dump + ] def __init__(self): self._records = [] From b392cb972aa9b8556b5297fdfd133ec34309323c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:24:03 +0800 Subject: [PATCH 0127/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++----- 
python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 8540c1cbdfd..6e5dfc65e79 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -83,8 +83,8 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() -def convert_dumps_to_logical(physical_dumps: List[Any]): - return _Accumulator.get_class().convert_dumps_to_logical(physical_dumps) +def postprocess_dumps(physical_dumps: List[Any]): + return _Accumulator.get_class().postprocess_dumps(physical_dumps) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -171,7 +171,7 @@ def get_class() -> Type["_Accumulator"]: return _StatAccumulator @classmethod - def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any]): raise NotImplementedError def append(self, single_pass_physical_count: torch.Tensor): @@ -186,7 +186,7 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod - def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any]): # Do not convert to logical since we want all details return [ record @@ -211,7 +211,7 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod - def convert_dumps_to_logical(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any]): TODO def __init__(self): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d76bdd3c23f..03f477719d3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -655,7 +655,7 @@ async def stop_expert_distribution_record(self): async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) - return expert_distribution.convert_dumps_to_logical([output.dump_output for output in raw_outputs]) + return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs]) async def update_weights_from_disk( self, From 594b7510425ed9994e59c01cf19bb9c913cba9a1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:24:25 +0800 Subject: [PATCH 0128/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6e5dfc65e79..f20ab92f5d4 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -212,7 +212,8 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps(cls, physical_dumps: List[Any]): - TODO + logical_count = torch.zeros((num_layers, num_logical_experts)) + return dict(logical_count=logical_count) def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) From f09eb0bda0488bfd54b62f61944bed783929ec0c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:25:15 +0800 Subject: [PATCH 0129/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git 
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f20ab92f5d4..db5867bbd69 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -213,6 +213,11 @@ class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps(cls, physical_dumps: List[Any]): logical_count = torch.zeros((num_layers, num_logical_experts)) + # Most naive implementation, can optimize if it is bottleneck + for physical_dump in physical_dumps: + for layer_index in range(num_layers): + for physical_expert_index in range(num_local_physical_experts): + TODO return dict(logical_count=logical_count) def __init__(self): From 2c73330de65e3f5c80202798e02687f1cb52cee3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:25:33 +0800 Subject: [PATCH 0130/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index db5867bbd69..cb50602413c 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -217,7 +217,9 @@ def postprocess_dumps(cls, physical_dumps: List[Any]): for physical_dump in physical_dumps: for layer_index in range(num_layers): for physical_expert_index in range(num_local_physical_experts): - TODO + logical_expert_index = TODO + logical_count[layer_index, logical_expert_index] += physical_dump[ + layer_index, physical_expert_index] return dict(logical_count=logical_count) def __init__(self): From a6cc800f7397bd4c34c0082f6777a886922b6385 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:25:48 +0800 Subject: [PATCH 0131/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cb50602413c..f641f585ad9 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -216,10 +216,10 @@ def postprocess_dumps(cls, physical_dumps: List[Any]): # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: for layer_index in range(num_layers): - for physical_expert_index in range(num_local_physical_experts): + for local_physical_expert_index in range(num_local_physical_experts): logical_expert_index = TODO logical_count[layer_index, logical_expert_index] += physical_dump[ - layer_index, physical_expert_index] + layer_index, local_physical_expert_index] return dict(logical_count=logical_count) def __init__(self): From 5dfc75b34631e2d538e3ae8de63c256b4d0ef25a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:27:31 +0800 Subject: [PATCH 0132/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++----- python/sglang/srt/managers/tokenizer_manager.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f641f585ad9..ded1e25e407 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -83,8 +83,8 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() -def postprocess_dumps(physical_dumps: List[Any]): - return 
_Accumulator.get_class().postprocess_dumps(physical_dumps) +def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): + return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -171,7 +171,7 @@ def get_class() -> Type["_Accumulator"]: return _StatAccumulator @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): raise NotImplementedError def append(self, single_pass_physical_count: torch.Tensor): @@ -186,7 +186,7 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): # Do not convert to logical since we want all details return [ record @@ -211,7 +211,7 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any]): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): logical_count = torch.zeros((num_layers, num_logical_experts)) # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 03f477719d3..bbbc8daa384 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -655,7 +655,8 @@ async def stop_expert_distribution_record(self): async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) - return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs]) + return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs], + physical_to_logical_map=TODO) async def update_weights_from_disk( self, From e13b53678604217d3d266acde74b9f8deabe6360 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:28:18 +0800 Subject: [PATCH 0133/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ded1e25e407..a917102b5a0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -217,7 +217,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t for physical_dump in physical_dumps: for layer_index in range(num_layers): for local_physical_expert_index in range(num_local_physical_experts): - logical_expert_index = TODO + global_physical_expert_index = TODO + local_physical_expert_index + logical_expert_index = physical_to_logical_map[layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump[ layer_index, local_physical_expert_index] return dict(logical_count=logical_count) From a47caa524128ef88453f8297ad480fd4429c1f78 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:28:38 +0800 Subject: [PATCH 0134/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a917102b5a0..e62a992033b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -217,9 +217,10 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t for physical_dump in physical_dumps: for layer_index in range(num_layers): for local_physical_expert_index in range(num_local_physical_experts): - global_physical_expert_index = TODO + local_physical_expert_index + global_physical_expert_index = num_local_physical_experts * physical_dump[ + 'rank'] + local_physical_expert_index logical_expert_index = physical_to_logical_map[layer_index, global_physical_expert_index] - logical_count[layer_index, logical_expert_index] += physical_dump[ + logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ layer_index, local_physical_expert_index] return dict(logical_count=logical_count) From 4c662b688af8f9efbace936ac4cda5090dd66c03 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:29:43 +0800 Subject: [PATCH 0135/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e62a992033b..303aafda6b0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -199,6 +199,7 @@ def __init__(self): def append(self, single_pass_physical_count: torch.Tensor): self._records.append(dict( + rank=TODO, physical_count=single_pass_physical_count.tolist(), )) @@ -235,5 +236,6 @@ def reset(self): def dump(self): return dict( + rank=TODO, physical_count=self._physical_count.tolist(), ) From 4387c3f5c8940b10d402dd29550f2dfe6bce3a98 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:30:46 +0800 Subject: [PATCH 0136/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 ++++++++------- python/sglang/srt/model_executor/model_runner.py | 7 +++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 303aafda6b0..01d827d78c1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -26,15 +26,15 @@ def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) @contextmanager - def with_forward_pass(self): + def with_forward_pass(self, forward_pass_id: int): try: yield finally: - self._on_forward_pass_end() + self._on_forward_pass_end(forward_pass_id) - def _on_forward_pass_end(self): + def _on_forward_pass_end(self, forward_pass_id: int): single_pass_physical_count = self._single_pass_gatherer.collect() - self._accumulator.append(single_pass_physical_count) + self._accumulator.append(forward_pass_id, single_pass_physical_count) self._single_pass_gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): @@ -174,7 +174,7 @@ def get_class() -> Type["_Accumulator"]: def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): raise NotImplementedError - def append(self, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): raise NotImplementedError def reset(self): @@ -197,8 +197,9 @@ def postprocess_dumps(cls, 
physical_dumps: List[Any], physical_to_logical_map: t def __init__(self): self._records = [] - def append(self, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): self._records.append(dict( + forward_pass_id=forward_pass_id, rank=TODO, physical_count=single_pass_physical_count.tolist(), )) @@ -228,7 +229,7 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) - def append(self, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): self._physical_count += single_pass_physical_count def reset(self): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bb3a679295b..0b1d3d65b77 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -161,7 +160,7 @@ def __init__( ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -895,7 +894,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() @@ -972,7 +971,7 @@ def forward_idle(self, forward_batch: ForwardBatch): def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: - with expert_distribution_recorder.with_forward_pass(): + with expert_distribution_recorder.with_forward_pass(forward_pass_id): return self._forward_raw(forward_batch, skip_attn_backend_init) def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: From e86f54fa7c5487519bc0c8a012bc3f309be2d44b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:31:35 +0800 Subject: [PATCH 0137/1089] more --- python/sglang/srt/model_executor/model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0b1d3d65b77..f21f4cf3e1b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -124,6 +124,8 @@ def __init__( self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.forward_pass_id = 0 + # Model-specific adjustment self.model_specific_adjustment() @@ -971,7 +973,8 @@ def forward_idle(self, forward_batch: ForwardBatch): def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: - with expert_distribution_recorder.with_forward_pass(forward_pass_id): + self.forward_pass_id += 1 + with expert_distribution_recorder.with_forward_pass(self.forward_pass_id): return self._forward_raw(forward_batch, 
skip_attn_backend_init) def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: From 4d74806f675709c99943d87d7df2af5c8513f94f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:32:13 +0800 Subject: [PATCH 0138/1089] more --- python/sglang/srt/entrypoints/http_server.py | 41 +++++++++----------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1f93b475c26..94e59978221 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -231,7 +231,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results() -> AsyncIterator[bytes]: try: async for out in _global_state.tokenizer_manager.generate_request( - obj, request + obj, request ): yield b"data: " + orjson.dumps( out, option=orjson.OPT_NON_STR_KEYS @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -366,11 +366,8 @@ async def stop_expert_distribution_record_async(): @app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"]) async def dump_expert_distribution_record_async(): """Dump expert distribution record.""" - await _global_state.tokenizer_manager.dump_expert_distribution_record() - return Response( - content="Dump expert distribution record.\n", - status_code=200, - ) + content = await _global_state.tokenizer_manager.dump_expert_distribution_record() + return ORJSONResponse(content, status_code=200) @app.post("/update_weights_from_disk") @@ -398,7 +395,7 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R @app.post("/init_weights_update_group") async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request + obj: InitWeightsUpdateGroupReqInput, request: Request ): """Initialize the parameter update group.""" success, message = await _global_state.tokenizer_manager.init_weights_update_group( @@ -413,7 +410,7 @@ async def init_weights_update_group( @app.post("/update_weights_from_distributed") async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request + obj: UpdateWeightsFromDistributedReqInput, request: Request ): """Update model parameter from distributed online.""" success, message = ( @@ -443,7 +440,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): @app.api_route("/release_memory_occupation", methods=["GET", "POST"]) async def release_memory_occupation( - obj: ReleaseMemoryOccupationReqInput, request: Request + obj: ReleaseMemoryOccupationReqInput, request: Request ): """Release GPU memory occupation temporarily.""" try: @@ -454,7 +451,7 @@ async def release_memory_occupation( @app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) async def resume_memory_occupation( - obj: ResumeMemoryOccupationReqInput, request: Request + obj: ResumeMemoryOccupationReqInput, request: Request ): """Resume GPU memory occupation.""" try: @@ -637,10 +634,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for 
instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, @@ -657,9 +654,9 @@ def _create_error_response(e): def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, - launch_callback: Optional[Callable[[], None]] = None, + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, + launch_callback: Optional[Callable[[], None]] = None, ): """ Launch SRT (SGLang Runtime) Server. @@ -724,10 +721,10 @@ def launch_server( def _wait_and_warmup( - server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection], - image_token_text: str, - launch_callback: Optional[Callable[[], None]] = None, + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection], + image_token_text: str, + launch_callback: Optional[Callable[[], None]] = None, ): headers = {} url = server_args.url() From 925c0c44046b0cc72bcc70c516eb48783760dd6a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:35:43 +0800 Subject: [PATCH 0139/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++--- python/sglang/srt/model_executor/model_runner.py | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 01d827d78c1..4cb5228ab4b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,7 +2,7 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy -from typing import List, Type, Any +from typing import List, Type, Any, Optional import torch from sglang.srt.server_args import ServerArgs @@ -13,7 +13,7 @@ # --------------------------------------- Entrypoint ----------------------------------------- -class _ExpertDistributionRecorder: +class ExpertDistributionRecorder: """Global expert distribution recording""" def __init__(self): @@ -80,7 +80,7 @@ def dump_record(self): return output -expert_distribution_recorder = _ExpertDistributionRecorder() +global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f21f4cf3e1b..c931ba089c2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder, ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -215,6 +215,10 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() + self.expert_distribution_recorder = 
ExpertDistributionRecorder() + global global_expert_distribution_recorder + global_expert_distribution_recorder = self.expert_distribution_recorder + def model_specific_adjustment(self): server_args = self.server_args From ef7b83e5b0310cc6648167089cde674c8e8bd29a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:36:35 +0800 Subject: [PATCH 0140/1089] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 ++-- python/sglang/srt/layers/moe/topk.py | 4 ++-- python/sglang/srt/managers/scheduler.py | 8 ++++---- python/sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 4 ++-- python/sglang/srt/models/qwen2_moe.py | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4bf33649cea..42fa548f77f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,4 @@ -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder try: from deep_ep import Buffer @@ -198,7 +198,7 @@ def dispatch( handle, event, ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) + global_expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) self.tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, device=hidden_states.device, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 247901d64f3..170960a4f83 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() @@ -253,6 +253,6 @@ def select_experts( renormalize=renormalize, ) - expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) + global_expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a49668c063d..eb3023a808a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -51,7 +51,7 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -1909,11 +1909,11 @@ def stop_profile(self) -> None: def expert_distribution_handle(self, recv_req: ExpertDistributionReq): dump_output = None if recv_req == ExpertDistributionReq.START_RECORD: - expert_distribution_recorder.start_record() + global_expert_distribution_recorder.start_record() elif recv_req == ExpertDistributionReq.STOP_RECORD: - expert_distribution_recorder.stop_record() + 
global_expert_distribution_recorder.stop_record() elif recv_req == ExpertDistributionReq.DUMP_RECORD: - dump_output = expert_distribution_recorder.dump_record() + dump_output = global_expert_distribution_recorder.dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") return ExpertDistributionReqOutput(dump_output=dump_output) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c931ba089c2..b48bb7b82bd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import expert_distribution_recorder, ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6c0b5a3d124..5e0755ad6ca 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -66,7 +66,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -1270,7 +1270,7 @@ def forward( residual = None for i in range(len(self.layers)): - with expert_distribution_recorder.with_current_layer(i): + with global_expert_distribution_recorder.with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 57f7d111650..e110ff49455 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix @@ -366,7 +366,7 @@ def forward( hidden_states = input_embeds residual = None for i in range(len(self.layers)): - with expert_distribution_recorder.with_current_layer(i): + with global_expert_distribution_recorder.with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual From d6b7aa96e34a01dbb676aa17d6cd6b47c48a2d82 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:36:57 +0800 Subject: [PATCH 0141/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 4cb5228ab4b..1171b5a70ba 100644 --- 
a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -80,6 +80,7 @@ def dump_record(self): return output +# Put global args for easy access, just like `global_server_args_dict` global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None From db7222e53807991075947222a4318eaf43e8a690 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:37:21 +0800 Subject: [PATCH 0142/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1171b5a70ba..da864e0fefa 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -16,7 +16,7 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self): + def __init__(self, server_args: ServerArgs): self._recording = False self._current_layer_idx = Withable() self._single_pass_gatherer = _SinglePassGatherer.init_new(server_args) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b48bb7b82bd..edcdaa45de3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -215,7 +215,7 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - self.expert_distribution_recorder = ExpertDistributionRecorder() + self.expert_distribution_recorder = ExpertDistributionRecorder(server_args) global global_expert_distribution_recorder global_expert_distribution_recorder = self.expert_distribution_recorder From 94e2ff20f0f18d55e236288c27561b015026d72c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:38:52 +0800 Subject: [PATCH 0143/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index da864e0fefa..4967e502af6 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -19,8 +19,9 @@ class ExpertDistributionRecorder: def __init__(self, server_args: ServerArgs): self._recording = False self._current_layer_idx = Withable() - self._single_pass_gatherer = _SinglePassGatherer.init_new(server_args) self._accumulator = _Accumulator.init_new() + self._single_pass_gatherers = {k: _SinglePassGatherer.init_new(server_args) for k in + self._accumulator.get_single_pass_gatherer_keys()} def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) @@ -90,6 +91,7 @@ def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch. 
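postprocess_dumps, named in the hunk header above, ultimately folds each rank's local physical counts into a global logical count via physical_to_logical_map; the patches implement it as a naive triple loop with a note that it can be optimized if it becomes a bottleneck. One possible vectorized equivalent using scatter_add_, sketched with toy shapes and a made-up map (both are assumptions, not values from the source):

import torch

num_layers, num_local_physical_experts, num_logical_experts, num_ranks = 2, 2, 3, 2
# Shape: (num_layers, num_ranks * num_local_physical_experts); entries are
# logical expert indices for each global physical expert (made-up values).
physical_to_logical_map = torch.tensor([[0, 1, 1, 2],
                                        [2, 0, 0, 1]])

physical_dumps = [
    dict(rank=r, physical_count=torch.ones(num_layers, num_local_physical_experts))
    for r in range(num_ranks)
]

logical_count = torch.zeros(num_layers, num_logical_experts)
for dump in physical_dumps:
    start = dump["rank"] * num_local_physical_experts
    index = physical_to_logical_map[:, start:start + num_local_physical_experts]
    # Scatter-add this rank's slice onto its logical experts, layer by layer.
    logical_count.scatter_add_(1, index, dump["physical_count"])

# Every counted token should land on exactly one logical expert.
assert int(logical_count.sum()) == num_ranks * num_layers * num_local_physical_experts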
# --------------------------------------- SinglePassGatherer ----------------------------------------- + class _SinglePassGatherer(ABC): @staticmethod def init_new(server_args: ServerArgs) -> "_SinglePassGatherer": @@ -160,6 +162,9 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): # --------------------------------------- Accumulator ----------------------------------------- +_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary" + + class _Accumulator(ABC): @staticmethod def init_new() -> "_Accumulator": @@ -171,6 +176,9 @@ def get_class() -> Type["_Accumulator"]: return _DetailAccumulator return _StatAccumulator + def get_single_pass_gatherer_keys(self): + return ["primary"] + @classmethod def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): raise NotImplementedError @@ -198,6 +206,11 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t def __init__(self): self._records = [] + def get_single_pass_gatherer_keys(self): + if False: # TODO `server_args.enable_two_batch_overlap` + return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"] + return super().get_single_pass_gatherer_keys() + def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): self._records.append(dict( forward_pass_id=forward_pass_id, From a92833fb87feae318238b4f96bfc6572df240b4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:39:07 +0800 Subject: [PATCH 0144/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 4967e502af6..ffaaa06e0de 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -20,8 +20,10 @@ def __init__(self, server_args: ServerArgs): self._recording = False self._current_layer_idx = Withable() self._accumulator = _Accumulator.init_new() - self._single_pass_gatherers = {k: _SinglePassGatherer.init_new(server_args) for k in - self._accumulator.get_single_pass_gatherer_keys()} + self._single_pass_gatherers = { + k: _SinglePassGatherer.init_new(server_args) + for k in self._accumulator.get_single_pass_gatherer_keys() + } def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) From e3f4ac43e1e005270fa687db9359f88e56820515 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:40:30 +0800 Subject: [PATCH 0145/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ffaaa06e0de..f60ed2e6f18 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -43,13 +43,14 @@ def _on_forward_pass_end(self, forward_pass_id: int): def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return - self._single_pass_gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) + gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key()] + gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): if not self._recording: return - 
self._single_pass_gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, - num_recv_tokens_per_expert_list) + gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key()] + gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) def _reset(self): """Reset the expert distribution recorder.""" From 34fc0424efb6597a2fcd7ee0ab789575dfcde223 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:41:03 +0800 Subject: [PATCH 0146/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f60ed2e6f18..cf79993360f 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -180,7 +180,10 @@ def get_class() -> Type["_Accumulator"]: return _StatAccumulator def get_single_pass_gatherer_keys(self): - return ["primary"] + return [_SINGLE_PASS_GATHERER_KEY_PRIMARY] + + def get_single_pass_gatherer_key(self, debug_name: str): + return _SINGLE_PASS_GATHERER_KEY_PRIMARY @classmethod def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): @@ -214,6 +217,11 @@ def get_single_pass_gatherer_keys(self): return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"] return super().get_single_pass_gatherer_keys() + def get_single_pass_gatherer_key(self, debug_name: str): + if False: # TODO `server_args.enable_two_batch_overlap` + return debug_name + return super().get_single_pass_gatherer_key(debug_name) + def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): self._records.append(dict( forward_pass_id=forward_pass_id, From 5f8ad3287b7f7abf7611e985b1eaf01710c28fc2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:41:26 +0800 Subject: [PATCH 0147/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cf79993360f..9ffb26d2555 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -57,7 +57,8 @@ def _reset(self): logger.info("Resetting ExpertDistributionRecorder...") self._recording = False assert self._current_layer_idx.value is None - self._single_pass_gatherer.reset() + for gatherer in self._single_pass_gatherers.values(): + gatherer.reset() self._accumulator.reset() def start_record(self): From 1e48c81447920bfa7a65c2efaf19a4c7e2418aef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:41:59 +0800 Subject: [PATCH 0148/1089] more --- python/sglang/srt/managers/expert_distribution.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9ffb26d2555..a353fc8cb52 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -36,9 +36,10 @@ def with_forward_pass(self, forward_pass_id: int): self._on_forward_pass_end(forward_pass_id) def _on_forward_pass_end(self, forward_pass_id: int): - single_pass_physical_count = self._single_pass_gatherer.collect() - self._accumulator.append(forward_pass_id, single_pass_physical_count) - 
self._single_pass_gatherer.reset() + for gatherer_key, gatherer in self._single_pass_gatherers.items(): + single_pass_physical_count = gatherer.collect() + self._accumulator.append(forward_pass_id, gatherer_key, single_pass_physical_count) + gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: @@ -190,7 +191,7 @@ def get_single_pass_gatherer_key(self, debug_name: str): def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): raise NotImplementedError - def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): raise NotImplementedError def reset(self): @@ -223,7 +224,7 @@ def get_single_pass_gatherer_key(self, debug_name: str): return debug_name return super().get_single_pass_gatherer_key(debug_name) - def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): self._records.append(dict( forward_pass_id=forward_pass_id, rank=TODO, @@ -255,7 +256,7 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t def __init__(self): self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) - def append(self, forward_pass_id: int, single_pass_physical_count: torch.Tensor): + def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): self._physical_count += single_pass_physical_count def reset(self): From 097c6547a8d3c48106718a48da77661056c46d35 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:42:14 +0800 Subject: [PATCH 0149/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a353fc8cb52..467656340e0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -228,6 +228,7 @@ def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_c self._records.append(dict( forward_pass_id=forward_pass_id, rank=TODO, + gatherer_key=gatherer_key, physical_count=single_pass_physical_count.tolist(), )) From fae53ff82c60bf1dd9fcb932e78b1259c67162b8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:42:56 +0800 Subject: [PATCH 0150/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 467656340e0..cd4e2309da7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -161,6 +161,7 @@ def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_l # TODO Wait for LowLatency DeepEP merging # e.g. use naive tensor copying +# need to consider CUDA graph, e.g. 
add initialization and after-end class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): pass From b84adaf073873f3a21b085047a830a6950f93367 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:44:07 +0800 Subject: [PATCH 0151/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cd4e2309da7..ad8ce23d6af 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,6 +2,7 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy +from dataclasses import dataclass from typing import List, Type, Any, Optional import torch @@ -269,3 +270,12 @@ def dump(self): rank=TODO, physical_count=self._physical_count.tolist(), ) + + +# --------------------------------------- Misc ----------------------------------------- + +@dataclass +class ModelExpertInfo: + num_layers: int + num_local_physical_experts: int + num_logical_experts: int From bafc37b6a383baa0f43ed91873404cbbe9103877 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:45:49 +0800 Subject: [PATCH 0152/1089] more --- .../srt/managers/expert_distribution.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ad8ce23d6af..03a85609b72 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -17,12 +17,13 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs): + def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata"): self._recording = False self._current_layer_idx = Withable() - self._accumulator = _Accumulator.init_new() + self._metadata = metadata + self._accumulator = _Accumulator.init_new(metadata) self._single_pass_gatherers = { - k: _SinglePassGatherer.init_new(server_args) + k: _SinglePassGatherer.init_new(server_args, metadata) for k in self._accumulator.get_single_pass_gatherer_keys() } @@ -174,8 +175,8 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): class _Accumulator(ABC): @staticmethod - def init_new() -> "_Accumulator": - return _Accumulator.get_class()() + def init_new(metadata: "ModelExpertMetadata") -> "_Accumulator": + return _Accumulator.get_class()(metadata) @staticmethod def get_class() -> Type["_Accumulator"]: @@ -183,6 +184,9 @@ def get_class() -> Type["_Accumulator"]: return _DetailAccumulator return _StatAccumulator + def __init__(self, metadata: "ModelExpertMetadata"): + self._metadata = metadata + def get_single_pass_gatherer_keys(self): return [_SINGLE_PASS_GATHERER_KEY_PRIMARY] @@ -213,7 +217,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t for record in physical_dump ] - def __init__(self): + def __init__(self, metadata: "ModelExpertMetadata"): + super().__init__(metadata) self._records = [] def get_single_pass_gatherer_keys(self): @@ -256,7 +261,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t layer_index, local_physical_expert_index] return dict(logical_count=logical_count) - def __init__(self): + def __init__(self, metadata: "ModelExpertMetadata"): + super().__init__(metadata) self._physical_count = torch.zeros((num_layers, 
num_local_physical_experts)) def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): @@ -275,7 +281,7 @@ def dump(self): # --------------------------------------- Misc ----------------------------------------- @dataclass -class ModelExpertInfo: +class ModelExpertMetadata: num_layers: int num_local_physical_experts: int num_logical_experts: int From 0ddaadb8ea893ca54bbe5ba5880351d744eb0475 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:46:27 +0800 Subject: [PATCH 0153/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 03a85609b72..fa12e54a844 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -101,11 +101,14 @@ def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch. class _SinglePassGatherer(ABC): @staticmethod - def init_new(server_args: ServerArgs) -> "_SinglePassGatherer": + def init_new(server_args: ServerArgs, metadata: "ModelExpertMetadata") -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # TODO DeepEP low latency - return _DeepepNormalSinglePassGatherer() - return _LayerBasedSinglePassGatherer() + return _DeepepNormalSinglePassGatherer(metadata) + return _LayerBasedSinglePassGatherer(metadata) + + def __init__(self, metadata: "ModelExpertMetadata"): + self._metadata = metadata def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass @@ -121,7 +124,8 @@ def collect(self) -> torch.Tensor: class _LayerBasedSinglePassGatherer(_SinglePassGatherer): - def __init__(self): + def __init__(self, metadata: "ModelExpertMetadata"): + super().__init__(metadata) self._num_recv_tokens_per_expert_list_of_layer = {} def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): From 9e845f8de6522e480cd8bc6eb107a931ee2a11ea Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:47:06 +0800 Subject: [PATCH 0154/1089] more --- .../srt/managers/expert_distribution.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index fa12e54a844..279c0ccb408 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -131,7 +131,7 @@ def __init__(self, metadata: "ModelExpertMetadata"): def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer - assert 0 <= layer_idx < num_layers + assert 0 <= layer_idx < self._metadata.num_layers self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list def reset(self): @@ -139,8 +139,9 @@ def reset(self): def collect(self) -> torch.Tensor: data = [ - self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ([0] * num_local_physical_experts) - for layer_index in range(num_layers) + self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ( + [0] * self._metadata.num_local_physical_experts) + for layer_index in range(self._metadata.num_layers) ] return torch.tensor(data) @@ -151,7 +152,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list 
= topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - num_recv_tokens_per_expert_list = [0] * num_local_physical_experts + num_recv_tokens_per_expert_list = [0] * self._metadata.num_local_physical_experts for token_record in topk_ids_list: for expert_idx in token_record: num_recv_tokens_per_expert_list[expert_idx] += 1 @@ -253,12 +254,12 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): - logical_count = torch.zeros((num_layers, num_logical_experts)) + logical_count = torch.zeros((metadata.num_layers, metadata.num_logical_experts)) # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: - for layer_index in range(num_layers): - for local_physical_expert_index in range(num_local_physical_experts): - global_physical_expert_index = num_local_physical_experts * physical_dump[ + for layer_index in range(metadata.num_layers): + for local_physical_expert_index in range(metadata.num_local_physical_experts): + global_physical_expert_index = metadata.num_local_physical_experts * physical_dump[ 'rank'] + local_physical_expert_index logical_expert_index = physical_to_logical_map[layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ @@ -267,7 +268,7 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t def __init__(self, metadata: "ModelExpertMetadata"): super().__init__(metadata) - self._physical_count = torch.zeros((num_layers, num_local_physical_experts)) + self._physical_count = torch.zeros((self._metadata.num_layers, self._metadata.num_local_physical_experts)) def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): self._physical_count += single_pass_physical_count From 9f7d540c55a15dfb3e9987fb1041c850f78c6f05 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:47:43 +0800 Subject: [PATCH 0155/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 279c0ccb408..a19c0b8639e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -92,7 +92,8 @@ def dump_record(self): global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None -def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): +def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, + metadata: "ModelExpertMetadata"): return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map) @@ -199,7 +200,8 @@ def get_single_pass_gatherer_key(self, debug_name: str): return _SINGLE_PASS_GATHERER_KEY_PRIMARY @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, + metadata: "ModelExpertMetadata"): raise NotImplementedError def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): @@ -214,7 +216,8 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: 
torch.Tensor): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, + metadata: "ModelExpertMetadata"): # Do not convert to logical since we want all details return [ record @@ -253,7 +256,8 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor): + def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, + metadata: "ModelExpertMetadata"): logical_count = torch.zeros((metadata.num_layers, metadata.num_logical_experts)) # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: From 32b9dc23795a7665f01910f1a0c3090c946142a7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:48:24 +0800 Subject: [PATCH 0156/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 ++++++++------- python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a19c0b8639e..fc684fb0fc9 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -17,11 +17,11 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata"): + def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", rank: int): self._recording = False self._current_layer_idx = Withable() self._metadata = metadata - self._accumulator = _Accumulator.init_new(metadata) + self._accumulator = _Accumulator.init_new(metadata, rank) self._single_pass_gatherers = { k: _SinglePassGatherer.init_new(server_args, metadata) for k in self._accumulator.get_single_pass_gatherer_keys() @@ -181,8 +181,8 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): class _Accumulator(ABC): @staticmethod - def init_new(metadata: "ModelExpertMetadata") -> "_Accumulator": - return _Accumulator.get_class()(metadata) + def init_new(metadata: "ModelExpertMetadata", rank: int) -> "_Accumulator": + return _Accumulator.get_class()(metadata, rank) @staticmethod def get_class() -> Type["_Accumulator"]: @@ -190,8 +190,9 @@ def get_class() -> Type["_Accumulator"]: return _DetailAccumulator return _StatAccumulator - def __init__(self, metadata: "ModelExpertMetadata"): + def __init__(self, metadata: "ModelExpertMetadata", rank: int): self._metadata = metadata + self._rank = rank def get_single_pass_gatherer_keys(self): return [_SINGLE_PASS_GATHERER_KEY_PRIMARY] @@ -242,7 +243,7 @@ def get_single_pass_gatherer_key(self, debug_name: str): def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): self._records.append(dict( forward_pass_id=forward_pass_id, - rank=TODO, + rank=self._rank, gatherer_key=gatherer_key, physical_count=single_pass_physical_count.tolist(), )) @@ -282,7 +283,7 @@ def reset(self): def dump(self): return dict( - rank=TODO, + rank=self._rank, physical_count=self._physical_count.tolist(), ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index edcdaa45de3..27e732be4ed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -215,7 +215,7 @@ def initialize(self, min_per_gpu_memory: float): if 
self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - self.expert_distribution_recorder = ExpertDistributionRecorder(server_args) + self.expert_distribution_recorder = ExpertDistributionRecorder(server_args, TODO, TODO) global global_expert_distribution_recorder global_expert_distribution_recorder = self.expert_distribution_recorder From b6ea1d74b84d4c188257bd944de9cf904cc6d238 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:49:17 +0800 Subject: [PATCH 0157/1089] more --- python/sglang/srt/model_executor/model_runner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 27e732be4ed..df39711a3f0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,8 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder, \ + ModelExpertMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -215,7 +216,12 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - self.expert_distribution_recorder = ExpertDistributionRecorder(server_args, TODO, TODO) + model_expert_metadata = ModelExpertMetadata( + num_layers=TODO, + num_local_physical_experts=TODO, + num_logical_experts=TODO, + ) + self.expert_distribution_recorder = ExpertDistributionRecorder(server_args, model_expert_metadata, TODO) global global_expert_distribution_recorder global_expert_distribution_recorder = self.expert_distribution_recorder From 32ac6ef5586df0641124df54ba5a60096e3dd8a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:51:05 +0800 Subject: [PATCH 0158/1089] more --- python/sglang/srt/model_executor/model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index df39711a3f0..cc2b44708f7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -221,7 +221,11 @@ def initialize(self, min_per_gpu_memory: float): num_local_physical_experts=TODO, num_logical_experts=TODO, ) - self.expert_distribution_recorder = ExpertDistributionRecorder(server_args, model_expert_metadata, TODO) + self.expert_distribution_recorder = ExpertDistributionRecorder( + server_args, model_expert_metadata, + # TODO handle DP!=TP case + rank=self.tp_rank, + ) global global_expert_distribution_recorder global_expert_distribution_recorder = self.expert_distribution_recorder From c86c997b2382a345943d4c7c87a30f9da5e0d30e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:51:34 +0800 Subject: [PATCH 0159/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 ++++++++ python/sglang/srt/model_executor/model_runner.py | 6 +----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git 
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index fc684fb0fc9..a00dae304bd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -295,3 +295,11 @@ class ModelExpertMetadata: num_layers: int num_local_physical_experts: int num_logical_experts: int + + @staticmethod + def from_model(model): + return ModelExpertMetadata( + num_layers=TODO, + num_local_physical_experts=TODO, + num_logical_experts=TODO, + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cc2b44708f7..b9a7128dcd9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -216,11 +216,7 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - model_expert_metadata = ModelExpertMetadata( - num_layers=TODO, - num_local_physical_experts=TODO, - num_logical_experts=TODO, - ) + model_expert_metadata = ModelExpertMetadata.from_model(self.model) self.expert_distribution_recorder = ExpertDistributionRecorder( server_args, model_expert_metadata, # TODO handle DP!=TP case From a95be4532006b496de9e67f90e412498c33efff6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:53:37 +0800 Subject: [PATCH 0160/1089] more --- .../srt/managers/expert_distribution.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a00dae304bd..f42c0c8c374 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -6,8 +6,10 @@ from typing import List, Type, Any, Optional import torch +from sglang.srt.configs.deepseekvl2 import DeepseekV2Config from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var +from transformers import Qwen2MoeConfig logger = logging.getLogger(__name__) @@ -298,8 +300,18 @@ class ModelExpertMetadata: @staticmethod def from_model(model): - return ModelExpertMetadata( - num_layers=TODO, - num_local_physical_experts=TODO, - num_logical_experts=TODO, - ) + config = model.config + if isinstance(config, DeepseekV2Config): + return ModelExpertMetadata( + num_layers=TODO, + num_local_physical_experts=TODO, + num_logical_experts=TODO, + ) + # TODO is it this class? 
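# [Editor's note, not part of the patch] Qwen2MoeConfig is the Hugging Face
# config class used by Qwen1.5-MoE / Qwen2-MoE checkpoints (model_type
# "qwen2_moe"), and it exposes num_hidden_layers and num_experts, so the
# isinstance check below is sound. A hedged, import-free alternative is to
# dispatch on attribute names rather than config classes; every name in this
# sketch other than the config attributes themselves is illustrative:
def sketch_expert_counts(config):
    # DeepSeek-V2 calls the routed-expert count n_routed_experts,
    # Qwen2-MoE calls it num_experts.
    num_logical = getattr(config, "n_routed_experts", None) or getattr(config, "num_experts", None)
    return None if num_logical is None else (config.num_hidden_layers, num_logical)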
+ if isinstance(config, Qwen2MoeConfig): + return ModelExpertMetadata( + num_layers=TODO, + num_local_physical_experts=TODO, + num_logical_experts=TODO, + ) + return None From d7f7ba2345877d6eb626b3e1f1d746b12085b540 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:54:41 +0800 Subject: [PATCH 0161/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f42c0c8c374..40b384a3cf0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -303,15 +303,15 @@ def from_model(model): config = model.config if isinstance(config, DeepseekV2Config): return ModelExpertMetadata( - num_layers=TODO, + num_layers=config.num_hidden_layers, num_local_physical_experts=TODO, - num_logical_experts=TODO, + num_logical_experts=config.n_routed_experts, ) # TODO is it this class? if isinstance(config, Qwen2MoeConfig): return ModelExpertMetadata( - num_layers=TODO, + num_layers=config.num_hidden_layers, num_local_physical_experts=TODO, - num_logical_experts=TODO, + num_logical_experts=config.num_experts, ) return None From a449b80b249a7d84b9ac3c6c79ec51e2e799a5a8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:55:33 +0800 Subject: [PATCH 0162/1089] more --- .../sglang/srt/managers/expert_distribution.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 40b384a3cf0..930c4a6aae2 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -302,16 +302,25 @@ class ModelExpertMetadata: def from_model(model): config = model.config if isinstance(config, DeepseekV2Config): - return ModelExpertMetadata( + return ModelExpertMetadata._init_new( num_layers=config.num_hidden_layers, - num_local_physical_experts=TODO, num_logical_experts=config.n_routed_experts, ) # TODO is it this class? 
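# [Editor's note, not part of the patch] The _init_new factory introduced in
# this hunk becomes the single place where num_local_physical_experts is
# derived; patch 0166 later fills it in as an even split across tensor-parallel
# ranks. A worked example of that assumption, with illustrative numbers:
num_logical_experts, tp_size = 64, 8
num_local_physical_experts = num_logical_experts // tp_size  # 8 experts per rank
# The global physical index of (rank, local_idx) is then
# rank * num_local_physical_experts + local_idx; duplicated or unevenly placed
# experts (the patch's own TODO) would invalidate this arithmetic.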
if isinstance(config, Qwen2MoeConfig): - return ModelExpertMetadata( + return ModelExpertMetadata._init_new( num_layers=config.num_hidden_layers, - num_local_physical_experts=TODO, num_logical_experts=config.num_experts, ) return None + + @staticmethod + def _init_new( + num_layers: int, + num_logical_experts: int, + ): + return ModelExpertMetadata( + num_layers=num_layers, + num_logical_experts=num_logical_experts, + num_local_physical_experts=TODO, + ) From 137e8cd4bcfdc6fdc8695d36a318e4be46f681e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:55:53 +0800 Subject: [PATCH 0163/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 930c4a6aae2..0eae6e9c6f3 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -19,7 +19,7 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", rank: int): + def __init__(self, server_args: ServerArgs, metadata: Optional["ModelExpertMetadata"], rank: int): self._recording = False self._current_layer_idx = Withable() self._metadata = metadata From 3b092be491b52e34ded20511b20bdfae1f336811 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:56:31 +0800 Subject: [PATCH 0164/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 0eae6e9c6f3..28366ac168e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -312,7 +312,7 @@ def from_model(model): num_layers=config.num_hidden_layers, num_logical_experts=config.num_experts, ) - return None + return ModelExpertMetadata._init_dummy() @staticmethod def _init_new( @@ -324,3 +324,11 @@ def _init_new( num_logical_experts=num_logical_experts, num_local_physical_experts=TODO, ) + + @staticmethod + def _init_dummy(): + return ModelExpertMetadata( + num_layers=1, + num_local_physical_experts=1, + num_logical_experts=1, + ) From a62ab027f11de8e193716739213525239b814ea3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:56:56 +0800 Subject: [PATCH 0165/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 28366ac168e..cb445b93383 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -19,7 +19,7 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs, metadata: Optional["ModelExpertMetadata"], rank: int): + def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", rank: int): self._recording = False self._current_layer_idx = Withable() self._metadata = metadata From 6660cd5aee951ad9d7aff18609e186121619aa03 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:58:02 +0800 Subject: [PATCH 0166/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cb445b93383..c1fc224f465 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -7,6 +7,7 @@ import torch from sglang.srt.configs.deepseekvl2 import DeepseekV2Config +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var from transformers import Qwen2MoeConfig @@ -322,7 +323,8 @@ def _init_new( return ModelExpertMetadata( num_layers=num_layers, num_logical_experts=num_logical_experts, - num_local_physical_experts=TODO, + # TODO handle more complex cases, e.g. duplicate some experts + num_local_physical_experts=num_logical_experts // get_tensor_model_parallel_world_size(), ) @staticmethod From f8ae307cfc1edb56461083ac743dddb876cca8e5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:59:20 +0800 Subject: [PATCH 0167/1089] more --- .../sglang/srt/managers/expert_distribution.py | 16 ++-------------- python/sglang/srt/models/deepseek_v2.py | 7 ++++++- python/sglang/srt/models/qwen2_moe.py | 7 ++++++- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c1fc224f465..6f760103dd6 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -10,7 +10,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var -from transformers import Qwen2MoeConfig logger = logging.getLogger(__name__) @@ -301,22 +300,11 @@ class ModelExpertMetadata: @staticmethod def from_model(model): - config = model.config - if isinstance(config, DeepseekV2Config): - return ModelExpertMetadata._init_new( - num_layers=config.num_hidden_layers, - num_logical_experts=config.n_routed_experts, - ) - # TODO is it this class? 
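# [Editor's note, not part of the patch] The lines removed below drop the
# centralized isinstance dispatch: from patch 0168 onward, from_model simply
# asks the model object for its own metadata and falls back to a dummy. A
# minimal sketch of that duck-typed lookup (dummy_factory is illustrative;
# get_model_expert_metadata is the method the series adds to each model):
def metadata_from_model(model, dummy_factory):
    getter = getattr(model, "get_model_expert_metadata", None)
    return getter() if getter is not None else dummy_factory()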
- if isinstance(config, Qwen2MoeConfig): - return ModelExpertMetadata._init_new( - num_layers=config.num_hidden_layers, - num_logical_experts=config.num_experts, - ) + return TODO return ModelExpertMetadata._init_dummy() @staticmethod - def _init_new( + def init_new( num_layers: int, num_logical_experts: int, ): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5e0755ad6ca..74d202717f0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -66,7 +66,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -1491,6 +1491,11 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() + def get_model_expert_metadata(self): + return ModelExpertMetadata.init_new( + num_layers=self.config.num_hidden_layers, + num_logical_experts=self.config.n_routed_experts, + ) class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index e110ff49455..dd92d957de8 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix @@ -484,5 +484,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) + def get_model_expert_metadata(self): + return ModelExpertMetadata.init_new( + num_layers=config.num_hidden_layers, + num_logical_experts=config.num_experts, + ) EntryClass = Qwen2MoeForCausalLM From 23047807d212230f69f2b5b53258f853ecfa330b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:59:33 +0800 Subject: [PATCH 0168/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6f760103dd6..826183b3826 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -300,7 +300,8 @@ class ModelExpertMetadata: @staticmethod def from_model(model): - return TODO + if hasattr(model, "get_model_expert_metadata"): + return model.get_model_expert_metadata() return ModelExpertMetadata._init_dummy() From 7e4bd0100b6bf17144fba45af4f05d6cb9e4b7861 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 09:59:59 +0800 Subject: [PATCH 0169/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 ------------ python/sglang/srt/models/deepseek_v2.py | 3 ++- python/sglang/srt/models/qwen2_moe.py | 7 ++++--- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 826183b3826..117e20cb3ff 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -304,18 +304,6 @@ def from_model(model): return model.get_model_expert_metadata() return ModelExpertMetadata._init_dummy() - @staticmethod - def init_new( - num_layers: int, - num_logical_experts: int, - ): - return ModelExpertMetadata( - num_layers=num_layers, - num_logical_experts=num_logical_experts, - # TODO handle more complex cases, e.g. duplicate some experts - num_local_physical_experts=num_logical_experts // get_tensor_model_parallel_world_size(), - ) - @staticmethod def _init_dummy(): return ModelExpertMetadata( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 74d202717f0..9ae5ee46417 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1492,9 +1492,10 @@ def set_embed_and_head(self, embed, head): torch.cuda.synchronize() def get_model_expert_metadata(self): - return ModelExpertMetadata.init_new( + return ModelExpertMetadata( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.n_routed_experts, + num_local_physical_experts=TODO, ) class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index dd92d957de8..5ad76d173dc 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -485,9 +485,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) def get_model_expert_metadata(self): - return ModelExpertMetadata.init_new( - num_layers=config.num_hidden_layers, - num_logical_experts=config.num_experts, + return ModelExpertMetadata( + num_layers=self.config.num_hidden_layers, + num_logical_experts=self.config.num_experts, + num_local_physical_experts=TODO, ) EntryClass = Qwen2MoeForCausalLM From df8524fcd7bb14228a2170da7a34d3d8502e7680 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:01:45 +0800 Subject: [PATCH 0170/1089] more --- python/sglang/srt/models/deepseek_v2.py | 3 ++- python/sglang/srt/models/qwen2_moe.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9ae5ee46417..0c0e2de0b6a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1495,7 +1495,8 @@ def get_model_expert_metadata(self): return ModelExpertMetadata( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.n_routed_experts, - num_local_physical_experts=TODO, + # TODO handle more complex cases like duplicating experts on different GPUs + num_local_physical_experts=self.config.n_routed_experts // get_tensor_model_parallel_world_size(), ) class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 5ad76d173dc..fd65e25f050 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -488,7 +488,8 @@ def get_model_expert_metadata(self): return ModelExpertMetadata( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.num_experts, - num_local_physical_experts=TODO, + # TODO handle more complex cases like duplicating experts on different GPUs 
+ num_local_physical_experts=self.config.num_experts // get_tensor_model_parallel_world_size(), ) EntryClass = Qwen2MoeForCausalLM From 189de3fdba725e9afefd56a52458732231b979e6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:01:59 +0800 Subject: [PATCH 0171/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 117e20cb3ff..d3fb92bc7d2 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -6,8 +6,6 @@ from typing import List, Type, Any, Optional import torch -from sglang.srt.configs.deepseekvl2 import DeepseekV2Config -from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var From e867aceb587ab6f4aa9c9d68c55ed4b7fd78f7be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:02:36 +0800 Subject: [PATCH 0172/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index d3fb92bc7d2..eac10f1717b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -46,13 +46,13 @@ def _on_forward_pass_end(self, forward_pass_id: int): def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return - gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key()] + gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key(debug_name)] gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): if not self._recording: return - gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key()] + gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key(debug_name)] gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) def _reset(self): @@ -94,7 +94,7 @@ def dump_record(self): def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, metadata: "ModelExpertMetadata"): - return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map) + return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map, metadata) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -226,8 +226,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t for record in physical_dump ] - def __init__(self, metadata: "ModelExpertMetadata"): - super().__init__(metadata) + def __init__(self, metadata: "ModelExpertMetadata", rank: int): + super().__init__(metadata, rank) self._records = [] def get_single_pass_gatherer_keys(self): @@ -271,8 +271,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t layer_index, local_physical_expert_index] return dict(logical_count=logical_count) - def __init__(self, metadata: "ModelExpertMetadata"): - super().__init__(metadata) + def __init__(self, metadata: "ModelExpertMetadata", rank: int): + super().__init__(metadata, 
rank) self._physical_count = torch.zeros((self._metadata.num_layers, self._metadata.num_local_physical_experts)) def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): From 9cf2f0b5fb13526cd99515f665c8e875b7da6a48 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:03:41 +0800 Subject: [PATCH 0173/1089] more --- .../sglang/srt/managers/expert_distribution.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index eac10f1717b..89ebbd617bd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -20,6 +20,7 @@ class ExpertDistributionRecorder: def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", rank: int): self._recording = False self._current_layer_idx = Withable() + self._current_debug_name = Withable() self._metadata = metadata self._accumulator = _Accumulator.init_new(metadata, rank) self._single_pass_gatherers = { @@ -30,6 +31,9 @@ def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", ran def with_current_layer(self, layer_idx): return self._current_layer_idx.with_value(layer_idx) + def with_debug_name(self, debug_name): + return self._current_debug_name.with_value(debug_name) + @contextmanager def with_forward_pass(self, forward_pass_id: int): try: @@ -46,13 +50,15 @@ def _on_forward_pass_end(self, forward_pass_id: int): def on_select_experts(self, topk_ids: torch.Tensor): if not self._recording: return - gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key(debug_name)] + gatherer = self._single_pass_gatherers[ + self._accumulator.get_single_pass_gatherer_key(self._current_debug_name.value)] gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): if not self._recording: return - gatherer = self._single_pass_gatherers[self._accumulator.get_single_pass_gatherer_key(debug_name)] + gatherer = self._single_pass_gatherers[ + self._accumulator.get_single_pass_gatherer_key(self._current_debug_name.value)] gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) def _reset(self): @@ -197,7 +203,7 @@ def __init__(self, metadata: "ModelExpertMetadata", rank: int): def get_single_pass_gatherer_keys(self): return [_SINGLE_PASS_GATHERER_KEY_PRIMARY] - def get_single_pass_gatherer_key(self, debug_name: str): + def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return _SINGLE_PASS_GATHERER_KEY_PRIMARY @classmethod @@ -235,9 +241,9 @@ def get_single_pass_gatherer_keys(self): return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"] return super().get_single_pass_gatherer_keys() - def get_single_pass_gatherer_key(self, debug_name: str): + def get_single_pass_gatherer_key(self, debug_name: Optional[str]): if False: # TODO `server_args.enable_two_batch_overlap` - return debug_name + return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY return super().get_single_pass_gatherer_key(debug_name) def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): From 0c03f39038f2752fdb7b16b50617ca96637b40d2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:05:06 +0800 Subject: [PATCH 0174/1089] more --- test/srt/test_expert_distribution.py | 62 
++++++++++------------------ 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index e3826303d16..62b1164c66b 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -70,51 +70,31 @@ def test_expert_distribution_record(self): ) self.assertEqual(response.status_code, 200) - # Verify the dumped file exists and has correct format - csv_files = glob.glob("expert_distribution_*.csv") - self.assertEqual( - len(csv_files), - 1, - f"Expected exactly one expert distribution CSV file {csv_files=}", - ) - - # Check CSV file format - with open(csv_files[0], "r") as f: - csv_reader = csv.reader(f) + # Check data rows + rows = list(csv_reader) + self.assertGreater(len(rows), 0, "CSV file should contain data rows") - # Check header - header = next(csv_reader) + for row in rows: + # Verify each row has 3 columns self.assertEqual( - header, - ["layer_id", "expert_id", "count"], - "CSV header should be 'layer_id,expert_id,count'", + len(row), + 3, + "Each row should have layer_id, expert_id and count", ) - # Check data rows - rows = list(csv_reader) - self.assertGreater(len(rows), 0, "CSV file should contain data rows") - - for row in rows: - # Verify each row has 3 columns - self.assertEqual( - len(row), - 3, - "Each row should have layer_id, expert_id and count", - ) - - # Verify data types - layer_id, expert_id, count = row - self.assertTrue( - layer_id.isdigit(), - f"layer_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - expert_id.isdigit(), - f"expert_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - count.isdigit(), f"count should be an integer {row=} {rows=}" - ) + # Verify data types + layer_id, expert_id, count = row + self.assertTrue( + layer_id.isdigit(), + f"layer_id should be an integer {row=} {rows=}", + ) + self.assertTrue( + expert_id.isdigit(), + f"expert_id should be an integer {row=} {rows=}", + ) + self.assertTrue( + count.isdigit(), f"count should be an integer {row=} {rows=}" + ) finally: kill_process_tree(process.pid) From b7864ad9ae374364c510b15c1a52990e1832df7b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:05:18 +0800 Subject: [PATCH 0175/1089] more --- test/srt/test_expert_distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 62b1164c66b..9363cf9859f 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -27,6 +27,9 @@ def tearDown(self): os.remove(f) def test_expert_distribution_record(self): + self._execute_core() + + def _execute_core(self): """Test expert distribution record endpoints""" process = popen_launch_server( # The feature is only implemented in deepseek_v2.py From 46e673bb02c28dd740757a8266d09fb0a8c1f760 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:06:41 +0800 Subject: [PATCH 0176/1089] more --- test/srt/test_expert_distribution.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 9363cf9859f..2425bc85920 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -1,10 +1,8 @@ -import csv import glob import os import unittest import requests - from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -27,13 +25,17 @@ def 
tearDown(self): os.remove(f) def test_expert_distribution_record(self): - self._execute_core() + for model_path in [ + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + "Qwen/Qwen1.5-MoE-A2.7B", + ]: + with self.subTest(model_path=model_path): + self._execute_core(model_path=model_path) - def _execute_core(self): + def _execute_core(self, model_path: str): """Test expert distribution record endpoints""" process = popen_launch_server( - # The feature is only implemented in deepseek_v2.py - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + model_path, DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ From 7af777e8747a684b50a93cfe70be5aa2b898694d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:07:05 +0800 Subject: [PATCH 0177/1089] more --- test/srt/test_expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 2425bc85920..c329e0dfa2a 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -57,7 +57,7 @@ def _execute_core(self, model_path: str): "text": "The capital of France is", "sampling_params": { "temperature": 0, - "max_new_tokens": 32, + "max_new_tokens": 3, }, }, ) From 0363d396708ccf2611d520bf96c0f8497724ac30 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:07:26 +0800 Subject: [PATCH 0178/1089] more --- test/srt/test_expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index c329e0dfa2a..3d5d905fe97 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -76,7 +76,7 @@ def _execute_core(self, model_path: str): self.assertEqual(response.status_code, 200) # Check data rows - rows = list(csv_reader) + data = response.json() self.assertGreater(len(rows), 0, "CSV file should contain data rows") for row in rows: From dde1118db879ae164d8a710b608c88ad7d62d2e0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:09:06 +0800 Subject: [PATCH 0179/1089] more --- test/srt/test_expert_distribution.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 3d5d905fe97..498674f8955 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -25,15 +25,17 @@ def tearDown(self): os.remove(f) def test_expert_distribution_record(self): - for model_path in [ - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - "Qwen/Qwen1.5-MoE-A2.7B", + for info in [ + dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", mode_detail=False), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=False), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=True), ]: - with self.subTest(model_path=model_path): - self._execute_core(model_path=model_path) + with self.subTest(info=info): + self._execute_core(**info) - def _execute_core(self, model_path: str): + def _execute_core(self, model_path: str, mode_detail: bool): """Test expert distribution record endpoints""" + os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = "1" if mode_detail else "0" process = popen_launch_server( model_path, DEFAULT_URL_FOR_TEST, From bb5da82e4826a1a080ac6695a0ffe148dd4914dd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:09:47 +0800 Subject: [PATCH 0180/1089] more --- 
test/srt/test_expert_distribution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 498674f8955..47910e5f9a9 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -79,6 +79,12 @@ def _execute_core(self, model_path: str, mode_detail: bool): # Check data rows data = response.json() + + if mode_detail: + TODO + else: + TODO + self.assertGreater(len(rows), 0, "CSV file should contain data rows") for row in rows: From b3c14d764d7acce9c9325378cb43439d24c449d3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:10:36 +0800 Subject: [PATCH 0181/1089] more --- test/srt/test_expert_distribution.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 47910e5f9a9..d39608d1f8d 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -79,14 +79,13 @@ def _execute_core(self, model_path: str, mode_detail: bool): # Check data rows data = response.json() + print(f"{data=}") if mode_detail: - TODO + self.assertGreater(len(data), 0, "Should contain data rows") else: TODO - self.assertGreater(len(rows), 0, "CSV file should contain data rows") - for row in rows: # Verify each row has 3 columns self.assertEqual( From 90f723d2dc2e6eb795b68e3f6ee0b8cce73ca798 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:11:41 +0800 Subject: [PATCH 0182/1089] more --- test/srt/test_expert_distribution.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index d39608d1f8d..436c2633a70 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,6 +3,7 @@ import unittest import requests +import torch from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -84,29 +85,9 @@ def _execute_core(self, model_path: str, mode_detail: bool): if mode_detail: self.assertGreater(len(data), 0, "Should contain data rows") else: - TODO - - for row in rows: - # Verify each row has 3 columns - self.assertEqual( - len(row), - 3, - "Each row should have layer_id, expert_id and count", - ) - - # Verify data types - layer_id, expert_id, count = row - self.assertTrue( - layer_id.isdigit(), - f"layer_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - expert_id.isdigit(), - f"expert_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - count.isdigit(), f"count should be an integer {row=} {rows=}" - ) + logical_count = torch.tensor(data['logical_count']) + print(f"{logical_count=}") + self.assertTrue(logical_count.sum() > 0) finally: kill_process_tree(process.pid) From 00a93e32bf42bf0a3024418468a5d51e6c0d050c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:12:12 +0800 Subject: [PATCH 0183/1089] more --- test/srt/test_expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 436c2633a70..4de137b6977 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -26,6 +26,7 @@ def tearDown(self): os.remove(f) def test_expert_distribution_record(self): + # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that) for info in [ 
dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", mode_detail=False), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=False), From 68f0b6527bd94eb89b4beb732b09e8ab9c4b2b3f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:13:29 +0800 Subject: [PATCH 0184/1089] more --- python/sglang/srt/managers/expert_distribution.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 89ebbd617bd..d95222d9404 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -48,18 +48,17 @@ def _on_forward_pass_end(self, forward_pass_id: int): gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): - if not self._recording: - return - gatherer = self._single_pass_gatherers[ - self._accumulator.get_single_pass_gatherer_key(self._current_debug_name.value)] - gatherer.on_select_experts(layer_idx=self._current_layer_idx.value, topk_ids=topk_ids) + self._on_hook("on_select_experts", topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): + self._on_hook("on_deepep_dispatch_normal", num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list) + + def _on_hook(self, hook_name: str, **kwargs): if not self._recording: return gatherer = self._single_pass_gatherers[ self._accumulator.get_single_pass_gatherer_key(self._current_debug_name.value)] - gatherer.on_deepep_dispatch_normal(self._current_layer_idx.value, num_recv_tokens_per_expert_list) + getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs) def _reset(self): """Reset the expert distribution recorder.""" From c9a8ac2861688596f0e619a13fc279e7e222df30 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:14:37 +0800 Subject: [PATCH 0185/1089] more --- test/srt/test_expert_distribution.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 4de137b6977..546b0c872a6 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -1,4 +1,3 @@ -import glob import os import unittest @@ -15,16 +14,6 @@ class TestExpertDistribution(CustomTestCase): - def setUp(self): - # Clean up any existing expert distribution files before each test - for f in glob.glob("expert_distribution_*.csv"): - os.remove(f) - - def tearDown(self): - # Clean up any expert distribution files after each test - for f in glob.glob("expert_distribution_*.csv"): - os.remove(f) - def test_expert_distribution_record(self): # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that) for info in [ From 0a22597e33c135849604394ebc7aeb897e1d4c59 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:15:41 +0800 Subject: [PATCH 0186/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bbbc8daa384..517e3a926ce 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -656,7 +656,7 @@ async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) return 
expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs], - physical_to_logical_map=TODO) + physical_to_logical_map=TODO, metadata=TODO) async def update_weights_from_disk( self, From 4408c78544988a060e0a4077383cd14ecdaa54d0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:17:09 +0800 Subject: [PATCH 0187/1089] more --- .../srt/managers/expert_distribution.py | 92 +++++++------------ python/sglang/srt/managers/expert_location.py | 22 +++++ 2 files changed, 57 insertions(+), 57 deletions(-) create mode 100644 python/sglang/srt/managers/expert_location.py diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index d95222d9404..42041e25c6a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,10 +2,10 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy -from dataclasses import dataclass from typing import List, Type, Any, Optional import torch +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -17,14 +17,14 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs, metadata: "ModelExpertMetadata", rank: int): + def __init__(self, server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int): self._recording = False self._current_layer_idx = Withable() self._current_debug_name = Withable() - self._metadata = metadata - self._accumulator = _Accumulator.init_new(metadata, rank) + self._expert_location_metadata = expert_location_metadata + self._accumulator = _Accumulator.init_new(expert_location_metadata, rank) self._single_pass_gatherers = { - k: _SinglePassGatherer.init_new(server_args, metadata) + k: _SinglePassGatherer.init_new(server_args, expert_location_metadata) for k in self._accumulator.get_single_pass_gatherer_keys() } @@ -98,8 +98,8 @@ def dump_record(self): def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - metadata: "ModelExpertMetadata"): - return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map, metadata) + expert_location_metadata: "ExpertLocationMetadata"): + return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map, expert_location_metadata) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -107,14 +107,14 @@ def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch. 
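# [Editor's note, not part of the patch] The gatherers below reduce one forward
# pass to a per-layer, per-expert token count. For the on_select_experts path,
# torch.bincount expresses the patch's per-token Python loop in a single call;
# this sketch assumes topk_ids is an integer tensor of shape (num_tokens, top_k):
import torch

def count_tokens_per_expert(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Flatten all (token, slot) routing decisions and histogram them per expert.
    return torch.bincount(topk_ids.flatten(), minlength=num_experts)

# Example: two tokens routed to experts {0, 2} and {2, 1} out of 4 experts
# yields tensor([1, 1, 2, 0]).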
class _SinglePassGatherer(ABC): @staticmethod - def init_new(server_args: ServerArgs, metadata: "ModelExpertMetadata") -> "_SinglePassGatherer": + def init_new(server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata") -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # TODO DeepEP low latency - return _DeepepNormalSinglePassGatherer(metadata) - return _LayerBasedSinglePassGatherer(metadata) + return _DeepepNormalSinglePassGatherer(expert_location_metadata) + return _LayerBasedSinglePassGatherer(expert_location_metadata) - def __init__(self, metadata: "ModelExpertMetadata"): - self._metadata = metadata + def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): + self._expert_location_metadata = expert_location_metadata def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass @@ -130,14 +130,14 @@ def collect(self) -> torch.Tensor: class _LayerBasedSinglePassGatherer(_SinglePassGatherer): - def __init__(self, metadata: "ModelExpertMetadata"): - super().__init__(metadata) + def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): + super().__init__(expert_location_metadata) self._num_recv_tokens_per_expert_list_of_layer = {} def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer - assert 0 <= layer_idx < self._metadata.num_layers + assert 0 <= layer_idx < self._expert_location_metadata.num_layers self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list def reset(self): @@ -146,8 +146,8 @@ def reset(self): def collect(self) -> torch.Tensor: data = [ self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ( - [0] * self._metadata.num_local_physical_experts) - for layer_index in range(self._metadata.num_layers) + [0] * self._expert_location_metadata.num_local_physical_experts) + for layer_index in range(self._expert_location_metadata.num_layers) ] return torch.tensor(data) @@ -158,7 +158,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - num_recv_tokens_per_expert_list = [0] * self._metadata.num_local_physical_experts + num_recv_tokens_per_expert_list = [0] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for expert_idx in token_record: num_recv_tokens_per_expert_list[expert_idx] += 1 @@ -186,8 +186,8 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): class _Accumulator(ABC): @staticmethod - def init_new(metadata: "ModelExpertMetadata", rank: int) -> "_Accumulator": - return _Accumulator.get_class()(metadata, rank) + def init_new(expert_location_metadata: "ExpertLocationMetadata", rank: int) -> "_Accumulator": + return _Accumulator.get_class()(expert_location_metadata, rank) @staticmethod def get_class() -> Type["_Accumulator"]: @@ -195,8 +195,8 @@ def get_class() -> Type["_Accumulator"]: return _DetailAccumulator return _StatAccumulator - def __init__(self, metadata: "ModelExpertMetadata", rank: int): - self._metadata = metadata + def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): + self._expert_location_metadata = expert_location_metadata self._rank = rank def get_single_pass_gatherer_keys(self): @@ -207,7 +207,7 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): @classmethod def 
postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - metadata: "ModelExpertMetadata"): + expert_location_metadata: "ExpertLocationMetadata"): raise NotImplementedError def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): @@ -223,7 +223,7 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - metadata: "ModelExpertMetadata"): + expert_location_metadata: "ExpertLocationMetadata"): # Do not convert to logical since we want all details return [ record @@ -231,8 +231,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: t for record in physical_dump ] - def __init__(self, metadata: "ModelExpertMetadata", rank: int): - super().__init__(metadata, rank) + def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): + super().__init__(expert_location_metadata, rank) self._records = [] def get_single_pass_gatherer_keys(self): @@ -263,22 +263,23 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - metadata: "ModelExpertMetadata"): - logical_count = torch.zeros((metadata.num_layers, metadata.num_logical_experts)) + expert_location_metadata: "ExpertLocationMetadata"): + logical_count = torch.zeros((expert_location_metadata.num_layers, expert_location_metadata.num_logical_experts)) # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: - for layer_index in range(metadata.num_layers): - for local_physical_expert_index in range(metadata.num_local_physical_experts): - global_physical_expert_index = metadata.num_local_physical_experts * physical_dump[ + for layer_index in range(expert_location_metadata.num_layers): + for local_physical_expert_index in range(expert_location_metadata.num_local_physical_experts): + global_physical_expert_index = expert_location_metadata.num_local_physical_experts * physical_dump[ 'rank'] + local_physical_expert_index logical_expert_index = physical_to_logical_map[layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ layer_index, local_physical_expert_index] return dict(logical_count=logical_count) - def __init__(self, metadata: "ModelExpertMetadata", rank: int): - super().__init__(metadata, rank) - self._physical_count = torch.zeros((self._metadata.num_layers, self._metadata.num_local_physical_experts)) + def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): + super().__init__(expert_location_metadata, rank) + self._physical_count = torch.zeros( + (self._expert_location_metadata.num_layers, self._expert_location_metadata.num_local_physical_experts)) def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): self._physical_count += single_pass_physical_count @@ -291,26 +292,3 @@ def dump(self): rank=self._rank, physical_count=self._physical_count.tolist(), ) - - -# --------------------------------------- Misc ----------------------------------------- - -@dataclass -class ModelExpertMetadata: - num_layers: int - num_local_physical_experts: int - num_logical_experts: int - - @staticmethod - def from_model(model): - if hasattr(model, "get_model_expert_metadata"): - return model.get_model_expert_metadata() - return 
ModelExpertMetadata._init_dummy() - - @staticmethod - def _init_dummy(): - return ModelExpertMetadata( - num_layers=1, - num_local_physical_experts=1, - num_logical_experts=1, - ) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py new file mode 100644 index 00000000000..57d46c41cb6 --- /dev/null +++ b/python/sglang/srt/managers/expert_location.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + + +@dataclass +class ExpertLocationMetadata: + num_layers: int + num_local_physical_experts: int + num_logical_experts: int + + @staticmethod + def from_model(model): + if hasattr(model, "get_model_expert_metadata"): + return model.get_model_expert_metadata() + return ExpertLocationMetadata._init_dummy() + + @staticmethod + def _init_dummy(): + return ExpertLocationMetadata( + num_layers=1, + num_local_physical_experts=1, + num_logical_experts=1, + ) From ffda545d03473a84f9cdcf12ae6f47e24f9bddfc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:18:09 +0800 Subject: [PATCH 0188/1089] more --- python/sglang/srt/managers/expert_location.py | 8 ++++++-- python/sglang/srt/models/deepseek_v2.py | 5 +++-- python/sglang/srt/models/qwen2_moe.py | 5 +++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 57d46c41cb6..05f1b43a0b3 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -9,10 +9,14 @@ class ExpertLocationMetadata: @staticmethod def from_model(model): - if hasattr(model, "get_model_expert_metadata"): - return model.get_model_expert_metadata() + if hasattr(model, "get_expert_location_metadata"): + return model.get_expert_location_metadata() return ExpertLocationMetadata._init_dummy() + @staticmethod + def init_new(**kwargs): + return ExpertLocationMetadata(**kwargs) + @staticmethod def _init_dummy(): return ExpertLocationMetadata( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0c0e2de0b6a..eb029dfd182 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -67,6 +67,7 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -1491,8 +1492,8 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() - def get_model_expert_metadata(self): - return ModelExpertMetadata( + def get_expert_location_metadata(self): + return ExpertLocationMetadata.init_new( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.n_routed_experts, # TODO handle more complex cases like duplicating experts on different GPUs diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index fd65e25f050..e16651ba769 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -42,6 +42,7 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata 
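# Annotation, not part of the patch: ExpertLocationMetadata.from_model() above
# duck-types on the model object, so MoE models opt in simply by defining
# get_expert_location_metadata(), while other models fall back to the dummy
# metadata. A hypothetical minimal opt-in, using the two-argument init_new()
# form this series settles on a few patches later; the config values are
# invented for illustration:
class ToyMoEForCausalLM:
    num_hidden_layers = 4
    num_experts = 8

    def get_expert_location_metadata(self):
        return ExpertLocationMetadata.init_new(
            num_layers=self.num_hidden_layers,
            num_logical_experts=self.num_experts,
        )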
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix @@ -484,8 +485,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - def get_model_expert_metadata(self): - return ModelExpertMetadata( + def get_expert_location_metadata(self): + return ExpertLocationMetadata.init_new( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.num_experts, # TODO handle more complex cases like duplicating experts on different GPUs From 61d5198808dbfe68f6b73254530c3696c2638b11 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:18:43 +0800 Subject: [PATCH 0189/1089] more --- python/sglang/srt/managers/expert_location.py | 12 ++++++++++-- python/sglang/srt/models/qwen2_moe.py | 2 -- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 05f1b43a0b3..6e87eb215de 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +from sglang.srt.distributed import get_tensor_model_parallel_world_size + @dataclass class ExpertLocationMetadata: @@ -14,8 +16,14 @@ def from_model(model): return ExpertLocationMetadata._init_dummy() @staticmethod - def init_new(**kwargs): - return ExpertLocationMetadata(**kwargs) + def init_new(num_layers: int, num_logical_experts: int): + # TODO handle more complex cases like duplicating experts on different GPUs + num_local_physical_experts = num_logical_experts // get_tensor_model_parallel_world_size() + return ExpertLocationMetadata( + num_layers=num_layers, + num_logical_experts=num_local_physical_experts, + num_local_physical_experts=num_local_physical_experts, + ) @staticmethod def _init_dummy(): diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index e16651ba769..11741de1d30 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -489,8 +489,6 @@ def get_expert_location_metadata(self): return ExpertLocationMetadata.init_new( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.num_experts, - # TODO handle more complex cases like duplicating experts on different GPUs - num_local_physical_experts=self.config.num_experts // get_tensor_model_parallel_world_size(), ) EntryClass = Qwen2MoeForCausalLM From 7525f70ce29fbbce5919c86869a359198931ba83 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:19:11 +0800 Subject: [PATCH 0190/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 +++++--------- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 42041e25c6a..7fe42c0e519 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -97,9 +97,8 @@ def dump_record(self): global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None -def postprocess_dumps(physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - expert_location_metadata: "ExpertLocationMetadata"): - return _Accumulator.get_class().postprocess_dumps(physical_dumps, physical_to_logical_map, expert_location_metadata) 
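# Annotation, not part of the patch: after this change the physical-to-logical
# map travels inside ExpertLocationMetadata rather than being threaded through
# call sites separately. A self-contained toy run of the merge that the
# _StatAccumulator performs downstream of postprocess_dumps(); every value
# here is made up:
import torch

physical_to_logical_map = torch.tensor([[0, 1, 0, 1]])  # 1 layer, 4 physical slots
physical_dumps = [
    dict(rank=0, physical_count=[[3, 5]]),  # 2 local physical experts per rank
    dict(rank=1, physical_count=[[2, 4]]),
]
logical_count = torch.zeros(1, 2)
for dump in physical_dumps:
    for local_idx, count in enumerate(dump["physical_count"][0]):
        global_idx = 2 * dump["rank"] + local_idx
        logical_count[0, physical_to_logical_map[0, global_idx]] += count
assert logical_count.tolist() == [[5.0, 9.0]]  # logical 0 <- 3+2, logical 1 <- 5+4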
+def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): + return _Accumulator.get_class().postprocess_dumps(physical_dumps, expert_location_metadata) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -206,8 +205,7 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return _SINGLE_PASS_GATHERER_KEY_PRIMARY @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - expert_location_metadata: "ExpertLocationMetadata"): + def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): raise NotImplementedError def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): @@ -222,8 +220,7 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - expert_location_metadata: "ExpertLocationMetadata"): + def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): # Do not convert to logical since we want all details return [ record @@ -262,8 +259,7 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], physical_to_logical_map: torch.Tensor, - expert_location_metadata: "ExpertLocationMetadata"): + def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): logical_count = torch.zeros((expert_location_metadata.num_layers, expert_location_metadata.num_logical_experts)) # Most naive implementation, can optimize if it is bottleneck for physical_dump in physical_dumps: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 517e3a926ce..1ab0262bade 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -656,7 +656,7 @@ async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs], - physical_to_logical_map=TODO, metadata=TODO) + expert_location_metadata=TODO) async def update_weights_from_disk( self, From 739f0b385aef027e4dc5215efaba70644eed0d62 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:19:23 +0800 Subject: [PATCH 0191/1089] more --- python/sglang/srt/managers/expert_location.py | 1 + python/sglang/srt/models/deepseek_v2.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6e87eb215de..6f0cc8d2fab 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -19,6 +19,7 @@ def from_model(model): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = num_logical_experts // get_tensor_model_parallel_world_size() + return ExpertLocationMetadata( num_layers=num_layers, num_logical_experts=num_local_physical_experts, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index eb029dfd182..dccb4054b3d 100644 --- 
a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1496,8 +1496,6 @@ def get_expert_location_metadata(self): return ExpertLocationMetadata.init_new( num_layers=self.config.num_hidden_layers, num_logical_experts=self.config.n_routed_experts, - # TODO handle more complex cases like duplicating experts on different GPUs - num_local_physical_experts=self.config.n_routed_experts // get_tensor_model_parallel_world_size(), ) class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): From 05ed6b94b84ce2ef83d7c942479b8e15fdc23974 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:20:10 +0800 Subject: [PATCH 0192/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- python/sglang/srt/managers/expert_location.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7fe42c0e519..7478cf1dacc 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -267,7 +267,8 @@ def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: for local_physical_expert_index in range(expert_location_metadata.num_local_physical_experts): global_physical_expert_index = expert_location_metadata.num_local_physical_experts * physical_dump[ 'rank'] + local_physical_expert_index - logical_expert_index = physical_to_logical_map[layer_index, global_physical_expert_index] + logical_expert_index = expert_location_metadata.physical_to_logical_map[ + layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ layer_index, local_physical_expert_index] return dict(logical_count=logical_count) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6f0cc8d2fab..ee05566e8ec 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +import torch from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -8,6 +9,9 @@ class ExpertLocationMetadata: num_layers: int num_local_physical_experts: int num_logical_experts: int + physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) + + # will have a `logical_to_physical_map` as well @staticmethod def from_model(model): From 3a87c40ade7a433ff9064e4a4bf7b5f535aa7abf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:20:39 +0800 Subject: [PATCH 0193/1089] more --- python/sglang/srt/managers/expert_location.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index ee05566e8ec..467dfed2f21 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -28,12 +28,13 @@ def init_new(num_layers: int, num_logical_experts: int): num_layers=num_layers, num_logical_experts=num_local_physical_experts, num_local_physical_experts=num_local_physical_experts, + physical_to_logical_map=_create_vanilla_physical_to_logical_map(), ) @staticmethod def _init_dummy(): - return ExpertLocationMetadata( - num_layers=1, - num_local_physical_experts=1, - num_logical_experts=1, - ) + return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) + + +def _create_vanilla_physical_to_logical_map(): + return 
TODO From bbd55ce0c92cee4c2c23d916ae6c67f7d9a0c153 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:21:22 +0800 Subject: [PATCH 0194/1089] more --- python/sglang/srt/managers/expert_location.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 467dfed2f21..3ccb2f905ae 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -26,9 +26,12 @@ def init_new(num_layers: int, num_logical_experts: int): return ExpertLocationMetadata( num_layers=num_layers, - num_logical_experts=num_local_physical_experts, + num_logical_experts=num_logical_experts, num_local_physical_experts=num_local_physical_experts, - physical_to_logical_map=_create_vanilla_physical_to_logical_map(), + physical_to_logical_map=_create_vanilla_physical_to_logical_map( + num_layers=num_layers, + num_physical_experts=num_local_physical_experts * TODO, + ), ) @staticmethod @@ -36,5 +39,5 @@ def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) -def _create_vanilla_physical_to_logical_map(): - return TODO +def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): + return torch.arange(0, num_physical_experts).repeat(num_layers, 1) From 9f67203c36d5ef033b86b9638f4c21c484a671fe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:21:34 +0800 Subject: [PATCH 0195/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3ccb2f905ae..8dc30ff30e9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -23,6 +23,7 @@ def from_model(model): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = num_logical_experts // get_tensor_model_parallel_world_size() + num_physical_experts = num_logical_experts return ExpertLocationMetadata( num_layers=num_layers, @@ -30,7 +31,7 @@ def init_new(num_layers: int, num_logical_experts: int): num_local_physical_experts=num_local_physical_experts, physical_to_logical_map=_create_vanilla_physical_to_logical_map( num_layers=num_layers, - num_physical_experts=num_local_physical_experts * TODO, + num_physical_experts=num_physical_experts, ), ) From 329927f6c576a966642600841a34abb7c75bd395 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:21:45 +0800 Subject: [PATCH 0196/1089] more --- python/sglang/srt/managers/expert_location.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8dc30ff30e9..69e31f8f530 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -9,10 +9,9 @@ class ExpertLocationMetadata: num_layers: int num_local_physical_experts: int num_logical_experts: int + # will have a `logical_to_physical_map` later physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) - # will have a `logical_to_physical_map` as well - @staticmethod def from_model(model): if hasattr(model, "get_expert_location_metadata"): From ab5e79948c7e4050db1e8980b549f96eb730a08c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: 
Tue, 1 Apr 2025 10:24:34 +0800 Subject: [PATCH 0197/1089] more --- python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/managers/tp_worker.py | 1 + python/sglang/srt/model_executor/model_runner.py | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index eb3023a808a..dc673f49a23 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -276,6 +276,7 @@ def __init__( self.random_seed, self.device, worker_global_server_args_dict, + self.expert_location_metadata, _, _, _, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 174f2e53321..f53941d056d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -145,6 +145,7 @@ def get_worker_info(self): self.random_seed, self.device, global_server_args_dict, + self.model_runner.expert_location_metadata, self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.max_context_len, self.model_runner.token_to_kv_pool.size, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b9a7128dcd9..5d32f6d5f33 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -46,6 +46,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder, \ ModelExpertMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -216,9 +217,9 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - model_expert_metadata = ModelExpertMetadata.from_model(self.model) + self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) self.expert_distribution_recorder = ExpertDistributionRecorder( - server_args, model_expert_metadata, + server_args, expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) From 1d24fc2f0e3502778fa8c5f8a0ac735d3999e429 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:25:17 +0800 Subject: [PATCH 0198/1089] more --- python/sglang/srt/managers/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index dc673f49a23..9d0317c53a9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2001,6 +2001,7 @@ def run_scheduler_process( "status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens, "max_req_input_len": scheduler.max_req_input_len, + "expert_location_metadata": scheduler.expert_location_metadata, } ) disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode From d3f6e7b88efc1f71376827b2cfdc2be226c0fbc2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:25:44 +0800 Subject: [PATCH 0199/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5d32f6d5f33..623456d5a5c 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -219,7 +219,7 @@ def initialize(self, min_per_gpu_memory: float): self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) self.expert_distribution_recorder = ExpertDistributionRecorder( - server_args, expert_location_metadata, + server_args, self.expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) From b14482be8d418c0a80120a468420e0d03ab3903f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:27:20 +0800 Subject: [PATCH 0200/1089] more --- python/sglang/srt/entrypoints/engine.py | 1 + python/sglang/srt/managers/tokenizer_manager.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index afc46c50f3b..fb1f55d6bf9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -574,4 +574,5 @@ def _launch_subprocesses( # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + tokenizer_manager.expert_location_metadata = scheduler_info["expert_location_metadata"] return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1ab0262bade..31d3ab5b82a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -229,6 +229,7 @@ def __init__( # Set after scheduler is initialized self.max_req_input_len = None + self.expert_location_metadata = None # Metrics if self.enable_metrics: @@ -656,7 +657,7 @@ async def dump_expert_distribution_record(self): raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( ExpertDistributionReq.DUMP_RECORD) return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs], - expert_location_metadata=TODO) + expert_location_metadata=self.expert_location_metadata) async def update_weights_from_disk( self, From 7d8ecfad9cb18be88bf1ebfce7ac7aba62ffcea2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:28:21 +0800 Subject: [PATCH 0201/1089] more --- python/sglang/srt/managers/data_parallel_controller.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index fb0264a6ea9..6eef426c640 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -100,6 +100,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: ) self.max_req_input_len = None + self.expert_location_metadata = None def launch_dp_schedulers(self, server_args, port_args): base_gpu_id = 0 @@ -219,6 +220,7 @@ def launch_tensor_parallel_group( self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] self.max_req_input_len = scheduler_info[0]["max_req_input_len"] + self.expert_location_metadata = scheduler_info[0]["expert_location_metadata"] def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) @@ -265,6 +267,7 @@ def run_data_parallel_controller_process( "status": "ready", "max_total_num_tokens": controller.max_total_num_tokens, "max_req_input_len": controller.max_req_input_len, + "expert_location_metadata": 
controller.expert_location_metadata, } ) if server_args.node_rank == 0: From e3ca65f5f8eb77db4a8cb63d3f6f08ad8fb4c407 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:29:52 +0800 Subject: [PATCH 0202/1089] more --- python/sglang/srt/model_executor/model_runner.py | 5 ++--- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/qwen2_moe.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 623456d5a5c..24782ed90bb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,8 +44,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder, \ - ModelExpertMetadata +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -985,7 +984,7 @@ def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: self.forward_pass_id += 1 - with expert_distribution_recorder.with_forward_pass(self.forward_pass_id): + with global_expert_distribution_recorder.with_forward_pass(self.forward_pass_id): return self._forward_raw(forward_batch, skip_attn_backend_init) def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index dccb4054b3d..0cd7e81f152 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -66,7 +66,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 11741de1d30..9ed06e7384e 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ModelExpertMetadata +from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader From 3d5610afb88fe734102859a143a50c217e22f78c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:32:08 +0800 Subject: [PATCH 0203/1089] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 ++-- python/sglang/srt/layers/moe/topk.py | 4 ++-- python/sglang/srt/managers/expert_distribution.py | 3 +-- 
python/sglang/srt/managers/scheduler.py | 8 ++++---- python/sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 4 ++-- python/sglang/srt/models/qwen2_moe.py | 4 ++-- 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 42fa548f77f..4bf33649cea 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,4 @@ -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder try: from deep_ep import Buffer @@ -198,7 +198,7 @@ def dispatch( handle, event, ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - global_expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) + expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) self.tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, device=hidden_states.device, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 170960a4f83..247901d64f3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() @@ -253,6 +253,6 @@ def select_experts( renormalize=renormalize, ) - global_expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) + expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7478cf1dacc..2c80c3a6a81 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -93,8 +93,7 @@ def dump_record(self): return output -# Put global args for easy access, just like `global_server_args_dict` -global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None +expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9d0317c53a9..3d7c4d2728e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -51,7 +51,7 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -1910,11 +1910,11 @@ def stop_profile(self) -> None: def expert_distribution_handle(self, recv_req: ExpertDistributionReq): dump_output = None if recv_req == ExpertDistributionReq.START_RECORD: - 
global_expert_distribution_recorder.start_record() + expert_distribution_recorder.start_record() elif recv_req == ExpertDistributionReq.STOP_RECORD: - global_expert_distribution_recorder.stop_record() + expert_distribution_recorder.stop_record() elif recv_req == ExpertDistributionReq.DUMP_RECORD: - dump_output = global_expert_distribution_recorder.dump_record() + dump_output = expert_distribution_recorder.dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") return ExpertDistributionReqOutput(dump_output=dump_output) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 24782ed90bb..ca1b969476c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder, ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder, ExpertDistributionRecorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0cd7e81f152..c0712769e9c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -66,7 +66,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -1271,7 +1271,7 @@ def forward( residual = None for i in range(len(self.layers)): - with global_expert_distribution_recorder.with_current_layer(i): + with expert_distribution_recorder.with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 9ed06e7384e..018374360b7 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -367,7 +367,7 @@ def forward( hidden_states = input_embeds residual = None for i in range(len(self.layers)): - with global_expert_distribution_recorder.with_current_layer(i): + with expert_distribution_recorder.with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual From 9f2d0447d5523b26a93ea460c652092751e64f19 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:32:32 +0800 Subject: 
[PATCH 0204/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- python/sglang/srt/model_executor/model_runner.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2c80c3a6a81..e41d6837488 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -17,7 +17,7 @@ class ExpertDistributionRecorder: """Global expert distribution recording""" - def __init__(self, server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int): + def initialize(self, server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int): self._recording = False self._current_layer_idx = Withable() self._current_debug_name = Withable() @@ -93,7 +93,7 @@ def dump_record(self): return output -expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None +expert_distribution_recorder = ExpertDistributionRecorder() def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ca1b969476c..2eaa9c44f44 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -217,13 +217,11 @@ def initialize(self, min_per_gpu_memory: float): self.model.set_eagle3_layers_to_capture() self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) - self.expert_distribution_recorder = ExpertDistributionRecorder( + expert_distribution_recorder.initialize( server_args, self.expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) - global global_expert_distribution_recorder - global_expert_distribution_recorder = self.expert_distribution_recorder def model_specific_adjustment(self): server_args = self.server_args From 6f16852430d0b1a5b6df455f79f30ce0c5554d92 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:34:06 +0800 Subject: [PATCH 0205/1089] more --- python/sglang/srt/model_executor/model_runner.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2eaa9c44f44..7357e3bec33 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -181,6 +181,13 @@ def initialize(self, min_per_gpu_memory: float): self.sampler = Sampler() self.load_model() + self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) + expert_distribution_recorder.initialize( + server_args, self.expert_location_metadata, + # TODO handle DP!=TP case + rank=self.tp_rank, + ) + # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -216,13 +223,6 @@ def initialize(self, min_per_gpu_memory: float): if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: self.model.set_eagle3_layers_to_capture() - self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) - expert_distribution_recorder.initialize( - server_args, self.expert_location_metadata, - # TODO handle DP!=TP case - rank=self.tp_rank, - ) - def model_specific_adjustment(self): server_args = self.server_args From 1306e2d88a61578968a167dd3742e449e2ae5a1b Mon Sep 17 
00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:35:26 +0800 Subject: [PATCH 0206/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7357e3bec33..374998fe596 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -982,7 +982,7 @@ def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: self.forward_pass_id += 1 - with global_expert_distribution_recorder.with_forward_pass(self.forward_pass_id): + with expert_distribution_recorder.with_forward_pass(self.forward_pass_id): return self._forward_raw(forward_batch, skip_attn_backend_init) def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: From fe482a1309b5cf8f73223d688a18703541dd1175 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:37:08 +0800 Subject: [PATCH 0207/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e41d6837488..6f663e9a2ac 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -269,7 +269,7 @@ def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: logical_expert_index = expert_location_metadata.physical_to_logical_map[ layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ - layer_index, local_physical_expert_index] + layer_index][local_physical_expert_index] return dict(logical_count=logical_count) def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): From 4f869a21e186af2d038c176c7e2843f6c2f4c1ac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:39:58 +0800 Subject: [PATCH 0208/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6f663e9a2ac..6f2ff80a309 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -270,7 +270,7 @@ def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: layer_index, global_physical_expert_index] logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][ layer_index][local_physical_expert_index] - return dict(logical_count=logical_count) + return dict(logical_count=logical_count.tolist()) def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) From 5a17b1720de25450ca124ddaee85a80fd180fcb5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:44:14 +0800 Subject: [PATCH 0209/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 374998fe596..73f2300bcb2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,7 
@@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import expert_distribution_recorder, ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( From bb79fedfe67f2f37a8478e47ad74feeb10591bd2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:45:51 +0800 Subject: [PATCH 0210/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6f2ff80a309..9dbfcd00d96 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -14,7 +14,7 @@ # --------------------------------------- Entrypoint ----------------------------------------- -class ExpertDistributionRecorder: +class _ExpertDistributionRecorder: """Global expert distribution recording""" def initialize(self, server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int): @@ -42,6 +42,8 @@ def with_forward_pass(self, forward_pass_id: int): self._on_forward_pass_end(forward_pass_id) def _on_forward_pass_end(self, forward_pass_id: int): + if not self._recording: + return for gatherer_key, gatherer in self._single_pass_gatherers.items(): single_pass_physical_count = gatherer.collect() self._accumulator.append(forward_pass_id, gatherer_key, single_pass_physical_count) @@ -93,7 +95,7 @@ def dump_record(self): return output -expert_distribution_recorder = ExpertDistributionRecorder() +expert_distribution_recorder = _ExpertDistributionRecorder() def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): From 486ec0f5a1689e2fae26655874803254a3019792 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:46:22 +0800 Subject: [PATCH 0211/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9dbfcd00d96..f5aa19da7d1 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -111,7 +111,7 @@ def init_new(server_args: ServerArgs, expert_location_metadata: "ExpertLocationM if server_args.enable_deepep_moe: # TODO DeepEP low latency return _DeepepNormalSinglePassGatherer(expert_location_metadata) - return _LayerBasedSinglePassGatherer(expert_location_metadata) + return _SelectExpertsSinglePassGatherer(expert_location_metadata) def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): self._expert_location_metadata = expert_location_metadata From 5c3a4764dd3391db8ac95d662d02e29d28a2320c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:48:38 +0800 Subject: [PATCH 0212/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 4 +- python/sglang/srt/entrypoints/http_server.py | 34 ++-- .../srt/layers/moe/ep_moe/token_dispatcher.py | 4 +- python/sglang/srt/layers/moe/topk.py | 1 - .../srt/managers/expert_distribution.py | 160 +++++++++++++----- 
python/sglang/srt/managers/expert_location.py | 5 +- python/sglang/srt/managers/scheduler.py | 25 +-- .../sglang/srt/managers/tokenizer_manager.py | 18 +- .../sglang/srt/model_executor/model_runner.py | 12 +- python/sglang/srt/models/deepseek_v2.py | 1 + python/sglang/srt/models/qwen2_moe.py | 8 +- python/sglang/srt/utils.py | 18 +- test/srt/test_expert_distribution.py | 12 +- 13 files changed, 208 insertions(+), 94 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index fb1f55d6bf9..5345a6d84f0 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -574,5 +574,7 @@ def _launch_subprocesses( # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - tokenizer_manager.expert_location_metadata = scheduler_info["expert_location_metadata"] + tokenizer_manager.expert_location_metadata = scheduler_info[ + "expert_location_metadata" + ] return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 94e59978221..f43afec5423 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -231,7 +231,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results() -> AsyncIterator[bytes]: try: async for out in _global_state.tokenizer_manager.generate_request( - obj, request + obj, request ): yield b"data: " + orjson.dumps( out, option=orjson.OPT_NON_STR_KEYS @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -395,7 +395,7 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R @app.post("/init_weights_update_group") async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request + obj: InitWeightsUpdateGroupReqInput, request: Request ): """Initialize the parameter update group.""" success, message = await _global_state.tokenizer_manager.init_weights_update_group( @@ -410,7 +410,7 @@ async def init_weights_update_group( @app.post("/update_weights_from_distributed") async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request + obj: UpdateWeightsFromDistributedReqInput, request: Request ): """Update model parameter from distributed online.""" success, message = ( @@ -440,7 +440,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): @app.api_route("/release_memory_occupation", methods=["GET", "POST"]) async def release_memory_occupation( - obj: ReleaseMemoryOccupationReqInput, request: Request + obj: ReleaseMemoryOccupationReqInput, request: Request ): """Release GPU memory occupation temporarily.""" try: @@ -451,7 +451,7 @@ async def release_memory_occupation( @app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) async def resume_memory_occupation( - obj: ResumeMemoryOccupationReqInput, request: Request + obj: ResumeMemoryOccupationReqInput, request: Request ): """Resume GPU memory occupation.""" try: @@ -634,10 +634,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, @@ -654,9 +654,9 @@ def _create_error_response(e): def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, - launch_callback: Optional[Callable[[], None]] = None, + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, + launch_callback: Optional[Callable[[], None]] = None, ): """ Launch SRT (SGLang Runtime) Server. 
@@ -721,10 +721,10 @@ def launch_server( def _wait_and_warmup( - server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection], - image_token_text: str, - launch_callback: Optional[Callable[[], None]] = None, + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection], + image_token_text: str, + launch_callback: Optional[Callable[[], None]] = None, ): headers = {} url = server_args.url() diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4bf33649cea..14146af99e3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -198,7 +198,9 @@ def dispatch( handle, event, ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - expert_distribution_recorder.on_deepep_dispatch_normal(num_recv_tokens_per_expert_list) + expert_distribution_recorder.on_deepep_dispatch_normal( + num_recv_tokens_per_expert_list + ) self.tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, device=hidden_states.device, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 247901d64f3..e30b7545452 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -24,7 +24,6 @@ _is_hip = is_hip() - def fused_topk_native( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f5aa19da7d1..976bd9358ee 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -2,9 +2,10 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy -from typing import List, Type, Any, Optional +from typing import Any, List, Optional, Type import torch + from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -14,10 +15,16 @@ # --------------------------------------- Entrypoint ----------------------------------------- + class _ExpertDistributionRecorder: """Global expert distribution recording""" - def initialize(self, server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int): + def initialize( + self, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, + ): self._recording = False self._current_layer_idx = Withable() self._current_debug_name = Withable() @@ -46,20 +53,28 @@ def _on_forward_pass_end(self, forward_pass_id: int): return for gatherer_key, gatherer in self._single_pass_gatherers.items(): single_pass_physical_count = gatherer.collect() - self._accumulator.append(forward_pass_id, gatherer_key, single_pass_physical_count) + self._accumulator.append( + forward_pass_id, gatherer_key, single_pass_physical_count + ) gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): - self._on_hook("on_deepep_dispatch_normal", num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list) + self._on_hook( + "on_deepep_dispatch_normal", + num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, + ) def _on_hook(self, hook_name: str, **kwargs): if not self._recording: 
return gatherer = self._single_pass_gatherers[ - self._accumulator.get_single_pass_gatherer_key(self._current_debug_name.value)] + self._accumulator.get_single_pass_gatherer_key( + self._current_debug_name.value + ) + ] getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs) def _reset(self): @@ -98,8 +113,12 @@ def dump_record(self): expert_distribution_recorder = _ExpertDistributionRecorder() -def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): - return _Accumulator.get_class().postprocess_dumps(physical_dumps, expert_location_metadata) +def postprocess_dumps( + physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" +): + return _Accumulator.get_class().postprocess_dumps( + physical_dumps, expert_location_metadata + ) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -107,7 +126,9 @@ def postprocess_dumps(physical_dumps: List[Any], expert_location_metadata: "Expe class _SinglePassGatherer(ABC): @staticmethod - def init_new(server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata") -> "_SinglePassGatherer": + def init_new( + server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata" + ) -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # TODO DeepEP low latency return _DeepepNormalSinglePassGatherer(expert_location_metadata) @@ -119,7 +140,9 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass - def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + def on_deepep_dispatch_normal( + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + ): pass def reset(self): @@ -134,19 +157,23 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): super().__init__(expert_location_metadata) self._num_recv_tokens_per_expert_list_of_layer = {} - def _on_layer_data(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + def _on_layer_data( + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + ): # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer assert 0 <= layer_idx < self._expert_location_metadata.num_layers - self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = num_recv_tokens_per_expert_list + self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = ( + num_recv_tokens_per_expert_list + ) def reset(self): self._num_recv_tokens_per_expert_list_of_layer.clear() def collect(self) -> torch.Tensor: data = [ - self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ( - [0] * self._expert_location_metadata.num_local_physical_experts) + self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) + or ([0] * self._expert_location_metadata.num_local_physical_experts) for layer_index in range(self._expert_location_metadata.num_layers) ] return torch.tensor(data) @@ -158,7 +185,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - num_recv_tokens_per_expert_list = [0] * self._expert_location_metadata.num_local_physical_experts + num_recv_tokens_per_expert_list = [ + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for expert_idx in token_record: 
num_recv_tokens_per_expert_list[expert_idx] += 1 @@ -167,7 +196,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): - def on_deepep_dispatch_normal(self, layer_idx: int, num_recv_tokens_per_expert_list: List[int]): + def on_deepep_dispatch_normal( + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + ): assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) @@ -186,7 +217,9 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): class _Accumulator(ABC): @staticmethod - def init_new(expert_location_metadata: "ExpertLocationMetadata", rank: int) -> "_Accumulator": + def init_new( + expert_location_metadata: "ExpertLocationMetadata", rank: int + ) -> "_Accumulator": return _Accumulator.get_class()(expert_location_metadata, rank) @staticmethod @@ -206,10 +239,19 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return _SINGLE_PASS_GATHERER_KEY_PRIMARY @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): + def postprocess_dumps( + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", + ): raise NotImplementedError - def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_physical_count: torch.Tensor, + ): raise NotImplementedError def reset(self): @@ -221,13 +263,13 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): + def postprocess_dumps( + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", + ): # Do not convert to logical since we want all details - return [ - record - for physical_dump in physical_dumps - for record in physical_dump - ] + return [record for physical_dump in physical_dumps for record in physical_dump] def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) @@ -243,13 +285,20 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY return super().get_single_pass_gatherer_key(debug_name) - def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor): - self._records.append(dict( - forward_pass_id=forward_pass_id, - rank=self._rank, - gatherer_key=gatherer_key, - physical_count=single_pass_physical_count.tolist(), - )) + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_physical_count: torch.Tensor, + ): + self._records.append( + dict( + forward_pass_id=forward_pass_id, + rank=self._rank, + gatherer_key=gatherer_key, + physical_count=single_pass_physical_count.tolist(), + ) + ) def reset(self): self._records.clear() @@ -260,26 +309,53 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod - def postprocess_dumps(cls, physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata"): - logical_count = torch.zeros((expert_location_metadata.num_layers, expert_location_metadata.num_logical_experts)) + def postprocess_dumps( + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", + ): + logical_count = torch.zeros( + ( + 
expert_location_metadata.num_layers,
+                expert_location_metadata.num_logical_experts,
+            )
+        )
         # Most naive implementation, can optimize if it is bottleneck
         for physical_dump in physical_dumps:
             for layer_index in range(expert_location_metadata.num_layers):
-                for local_physical_expert_index in range(expert_location_metadata.num_local_physical_experts):
-                    global_physical_expert_index = expert_location_metadata.num_local_physical_experts * physical_dump[
-                        'rank'] + local_physical_expert_index
-                    logical_expert_index = expert_location_metadata.physical_to_logical_map[
-                        layer_index, global_physical_expert_index]
-                    logical_count[layer_index, logical_expert_index] += physical_dump['physical_count'][
-                        layer_index][local_physical_expert_index]
+                for local_physical_expert_index in range(
+                    expert_location_metadata.num_local_physical_experts
+                ):
+                    global_physical_expert_index = (
+                        expert_location_metadata.num_local_physical_experts
+                        * physical_dump["rank"]
+                        + local_physical_expert_index
+                    )
+                    logical_expert_index = (
+                        expert_location_metadata.physical_to_logical_map[
+                            layer_index, global_physical_expert_index
+                        ]
+                    )
+                    logical_count[layer_index, logical_expert_index] += physical_dump[
+                        "physical_count"
+                    ][layer_index][local_physical_expert_index]
 
         return dict(logical_count=logical_count.tolist())
 
     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
         super().__init__(expert_location_metadata, rank)
         self._physical_count = torch.zeros(
-            (self._expert_location_metadata.num_layers, self._expert_location_metadata.num_local_physical_experts))
+            (
+                self._expert_location_metadata.num_layers,
+                self._expert_location_metadata.num_local_physical_experts,
+            )
+        )
 
-    def append(self, forward_pass_id: int, gatherer_key: str, single_pass_physical_count: torch.Tensor):
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_physical_count: torch.Tensor,
+    ):
         self._physical_count += single_pass_physical_count
 
     def reset(self):
diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 69e31f8f530..c533784f2cd 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 import torch
+
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 
 
@@ -21,7 +22,9 @@ def from_model(model):
     @staticmethod
     def init_new(num_layers: int, num_logical_experts: int):
         # TODO handle more complex cases like duplicating experts on different GPUs
-        num_local_physical_experts = num_logical_experts // get_tensor_model_parallel_world_size()
+        num_local_physical_experts = (
+            num_logical_experts // get_tensor_model_parallel_world_size()
+        )
         num_physical_experts = num_logical_experts
 
         return ExpertLocationMetadata(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 3d7c4d2728e..43ca7728753 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -32,6 +32,8 @@
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
+
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
@@ -128,7 +130,6 @@
     suppress_other_loggers,
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
-from torch.distributed import barrier
 
 logger = logging.getLogger(__name__)
 
@@ -359,8 +360,8 
@@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1236,10 +1237,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1264,9 +1265,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1470,8 +1471,8 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 31d3ab5b82a..8bb082c738b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -654,10 +655,15 @@ async def stop_expert_distribution_record(self): await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) async def dump_expert_distribution_record(self): - raw_outputs: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( - ExpertDistributionReq.DUMP_RECORD) - return expert_distribution.postprocess_dumps([output.dump_output for output in raw_outputs], - expert_location_metadata=self.expert_location_metadata) + raw_outputs: List[ExpertDistributionReqOutput] = ( + await self.expert_distribution_communicator( + ExpertDistributionReq.DUMP_RECORD + ) + ) + return expert_distribution.postprocess_dumps( + [output.dump_output for output in raw_outputs], + expert_location_metadata=self.expert_location_metadata, + ) async def update_weights_from_disk( self, @@ -961,8 +967,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 73f2300bcb2..544e76a9269 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from 
sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -163,7 +164,7 @@ def __init__( ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -183,7 +184,8 @@ def initialize(self, min_per_gpu_memory: float): self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) expert_distribution_recorder.initialize( - server_args, self.expert_location_metadata, + server_args, + self.expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) @@ -904,7 +906,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() @@ -985,7 +987,9 @@ def forward( with expert_distribution_recorder.with_forward_pass(self.forward_pass_id): return self._forward_raw(forward_batch, skip_attn_backend_init) - def _forward_raw(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool) -> LogitsProcessorOutput: + def _forward_raw( + self, forward_batch: ForwardBatch, skip_attn_backend_init: bool + ) -> LogitsProcessorOutput: if ( forward_batch.forward_mode.is_cuda_graph() and self.cuda_graph_runner diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c0712769e9c..18bf19e3e04 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1498,6 +1498,7 @@ def get_expert_location_metadata(self): num_logical_experts=self.config.n_routed_experts, ) + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 018374360b7..8c518213311 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -20,6 +20,9 @@ import torch import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -46,8 +49,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix -from torch import nn -from transformers import PretrainedConfig class Qwen2MoeMLP(nn.Module): @@ -194,7 +195,7 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -491,4 +492,5 @@ def get_expert_location_metadata(self): num_logical_experts=self.config.num_experts, ) + EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e013230b764..3d5509248ca 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -43,7 +43,19 @@ from io import BytesIO from multiprocessing.reduction import ForkingPickler from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union, TypeVar, Generic 
+from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Protocol, + Set, + Tuple, + TypeVar, + Union, +) import numpy as np import psutil @@ -1835,7 +1847,8 @@ def flatten_nested_list(nested_list): else: return [nested_list] -T = TypeVar('T') + +T = TypeVar("T") class Withable(Generic[T]): @@ -1855,4 +1868,3 @@ def with_value(self, new_value: T): finally: assert self._value is new_value self._value = None - diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 546b0c872a6..3812d2e068a 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,6 +3,7 @@ import requests import torch + from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -17,7 +18,10 @@ class TestExpertDistribution(CustomTestCase): def test_expert_distribution_record(self): # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that) for info in [ - dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", mode_detail=False), + dict( + model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + mode_detail=False, + ), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=False), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=True), ]: @@ -26,7 +30,9 @@ def test_expert_distribution_record(self): def _execute_core(self, model_path: str, mode_detail: bool): """Test expert distribution record endpoints""" - os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = "1" if mode_detail else "0" + os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = ( + "1" if mode_detail else "0" + ) process = popen_launch_server( model_path, DEFAULT_URL_FOR_TEST, @@ -75,7 +81,7 @@ def _execute_core(self, model_path: str, mode_detail: bool): if mode_detail: self.assertGreater(len(data), 0, "Should contain data rows") else: - logical_count = torch.tensor(data['logical_count']) + logical_count = torch.tensor(data["logical_count"]) print(f"{logical_count=}") self.assertTrue(logical_count.sum() > 0) From cabeddb14da44a6126cda54e659a9d7036cc81fa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 10:51:25 +0800 Subject: [PATCH 0213/1089] bump ci --- test/srt/test_expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 3812d2e068a..50acda9077a 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -56,7 +56,7 @@ def _execute_core(self, model_path: str, mode_detail: bool): "text": "The capital of France is", "sampling_params": { "temperature": 0, - "max_new_tokens": 3, + "max_new_tokens": 32, }, }, ) From 7821819286ce71b1e836958ebeafe694d9ef5e9e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 11:05:38 +0800 Subject: [PATCH 0214/1089] fix ci --- python/sglang/srt/managers/expert_distribution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 976bd9358ee..3356228eea4 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -19,15 +19,17 @@ class _ExpertDistributionRecorder: """Global expert distribution recording""" + def __init__(self): + self._recording = False + self._current_layer_idx = Withable() + self._current_debug_name = Withable() + def initialize( self, server_args: 
ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int, ): - self._recording = False - self._current_layer_idx = Withable() - self._current_debug_name = Withable() self._expert_location_metadata = expert_location_metadata self._accumulator = _Accumulator.init_new(expert_location_metadata, rank) self._single_pass_gatherers = { From ec98f75f54353db4ca021e154b9cec9220083ae5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 14:52:32 +0800 Subject: [PATCH 0215/1089] fix ci --- docs/backend/native_api.ipynb | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 72b65c6ca98..c02c53b1c8b 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -406,15 +406,7 @@ "print_highlight(response)\n", "\n", "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n", - "print_highlight(response)\n", - "\n", - "import glob\n", - "\n", - "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n", - "with open(output_file, \"r\") as f:\n", - " print_highlight(\"Content of dumped record:\")\n", - " for line in f:\n", - " print_highlight(line.strip())" + "print_highlight(response)\n" ] }, { From 1e5dfbc528879d1d541c64836ba179374da9dbd5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 1 Apr 2025 21:13:42 +0800 Subject: [PATCH 0216/1089] fmt --- docs/backend/native_api.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index c02c53b1c8b..3961358fb3f 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -406,7 +406,7 @@ "print_highlight(response)\n", "\n", "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n", - "print_highlight(response)\n" + "print_highlight(response)" ] }, { From a2236ec7bf29fe2c0af8dc1a2a35001cef7f5c6c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 9 Apr 2025 15:44:10 +0800 Subject: [PATCH 0217/1089] more --- python/sglang/srt/models/deepseek_v2.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d5f3321260e..55e6b084e32 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -279,10 +279,7 @@ def forward( return self.forward_deepep(hidden_states, forward_mode) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.n_shared_experts is not None and self.n_share_experts_fusion == 0: - shared_output = self.shared_experts(hidden_states) - else: - shared_output = None + shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) final_hidden_states = ( @@ -312,8 +309,7 @@ def forward_deepep( ): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + shared_output = self._forward_shared_experts(hidden_states) topk_weights, topk_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -363,6 +359,12 @@ def forward_deepep( return final_hidden_states + def _forward_shared_experts(self, hidden_states): + if self.n_shared_experts is not None and self.n_share_experts_fusion == 0: + return self.shared_experts(hidden_states) + else: + return None + def 
yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math From d62f2b3dbd689e6ba8b2593258358e044a5000ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 9 Apr 2025 21:28:48 +0800 Subject: [PATCH 0218/1089] doc --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 69994520b24..1853515b487 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1107,7 +1107,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--moe-dense-tp-size", type=int, default=ServerArgs.moe_dense_tp_size, - help="tp_size for MoE dense MLP layers", + help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.", ) parser.add_argument( "--deepep-mode", From c6aadb8a1752bc3dc519c11596daea0669dd6a7f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 9 Apr 2025 21:33:50 +0800 Subject: [PATCH 0219/1089] rename --- python/sglang/srt/models/deepseek_v2.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 417fa2b7ca9..981e0bdd428 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1014,8 +1014,10 @@ def forward_absorb_fused_mla_rope( class _DecoderLayerExecutionMode(Enum): - MLP_ONE = auto() - MLP_ALL = auto() + # The MLP sublayer requires 1/tp_size tokens as input + MLP_INPUT_ONE = auto() + # The MLP sublayer requires all tokens as input + MLP_INPUT_ALL = auto() @dataclass @@ -1113,7 +1115,8 @@ def __init__( ) self.input_is_scattered = ( - previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE + previous_layer_info.execution_mode + == _DecoderLayerExecutionMode.MLP_INPUT_ONE ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1134,10 +1137,10 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): and layer_id % config.moe_layer_freq == 0 ) execution_mode = ( - _DecoderLayerExecutionMode.MLP_ONE + _DecoderLayerExecutionMode.MLP_INPUT_ONE if (global_server_args_dict["enable_deepep_moe"] and is_sparse) or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) - else _DecoderLayerExecutionMode.MLP_ALL + else _DecoderLayerExecutionMode.MLP_INPUT_ALL ) return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) @@ -1148,18 +1151,18 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: - return self.forward_mode_mlp_one( + if self.info.execution_mode == _DecoderLayerExecutionMode.MLP_INPUT_ONE: + return self.forward_mode_mlp_input_one( positions, hidden_states, forward_batch, residual ) - elif self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: - return self.forward_mode_mlp_all( + elif self.info.execution_mode == _DecoderLayerExecutionMode.MLP_INPUT_ALL: + return self.forward_mode_mlp_input_all( positions, hidden_states, forward_batch, residual ) else: raise NotImplementedError - def forward_mode_mlp_all( + def forward_mode_mlp_input_all( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1226,7 +1229,7 @@ def forward_mode_mlp_all( return hidden_states, residual - def forward_mode_mlp_one( + def 
forward_mode_mlp_input_one( self, positions: torch.Tensor, hidden_states: torch.Tensor, From 054bbaf32282baf6a2af16099ec15e91719eec30 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 9 Apr 2025 21:36:12 +0800 Subject: [PATCH 0220/1089] fix diff --- python/sglang/srt/server_args.py | 72 ++++++++++++++++---------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1853515b487..6b833f443d7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -184,7 +184,7 @@ class ServerArgs: flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None moe_dense_tp_size: Optional[int] = None - n_share_experts_fusion: Optional[int] = None + n_share_experts_fusion: int = 0 disable_shared_experts_fusion: bool = False # Debug tensor dumps @@ -420,8 +420,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -445,21 +445,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -472,13 +472,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. 
Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -513,9 +513,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -557,8 +557,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -578,7 +578,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1043,7 +1043,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1056,8 +1056,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1122,7 +1122,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1136,7 +1136,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From c81c56e1d8c3cb0ded80f9bde71e5dfb298c5a5a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 9 Apr 2025 21:37:09 +0800 Subject: [PATCH 0221/1089] fmt --- python/sglang/srt/server_args.py | 70 ++++++++++++++++---------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6b833f443d7..5080cea8d04 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -420,8 +420,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -445,21 +445,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -472,13 +472,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. 
Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -513,9 +513,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -557,8 +557,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -578,7 +578,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1043,7 +1043,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1056,8 +1056,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1122,7 +1122,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1136,7 +1136,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 72e47a19ec54edcb6dd7c143b246382e34e890cb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:04:49 +0800 Subject: [PATCH 0222/1089] more --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5560cea6758..4accef613cd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -83,6 +83,7 @@ class TestFile: TestFile("models/lora/test_lora_tp.py", 300), TestFile("test_data_parallelism.py", 90), TestFile("test_dp_attention.py", 90), + TestFile("test_expert_distribution.py", 100), TestFile("test_mla_tp.py", 420), TestFile("test_moe_ep.py", 220), TestFile("test_patch_torch.py", 30), From 1dbde71fb6a747b566c70a441bad2bc8736b655e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:05:17 +0800 Subject: [PATCH 0223/1089] more --- test/srt/test_expert_distribution.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 50acda9077a..29429e63c2f 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,7 +3,6 @@ import requests import torch - from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -18,17 +17,15 @@ class TestExpertDistribution(CustomTestCase): def test_expert_distribution_record(self): # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that) for info in [ - dict( - model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - mode_detail=False, - ), - dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=False), + dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=1), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode_detail=True), ]: with self.subTest(info=info): self._execute_core(**info) - def _execute_core(self, model_path: str, mode_detail: bool): + def _execute_core(self, model_path: str, mode_detail: bool = False, tp_size: int = 1): """Test expert distribution record endpoints""" os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = ( "1" if mode_detail else "0" From e07a3235e1d6b99f65e5b204a97d44a502754c68 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:05:26 +0800 Subject: [PATCH 0224/1089] more --- test/srt/test_expert_distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 29429e63c2f..ed65998d2cf 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -36,6 +36,8 @@ def _execute_core(self, model_path: str, mode_detail: bool = False, tp_size: int timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--trust-remote-code", + "--tp-size", + str(tp_size), ], ) From 96d91837fe0fd995d716a7469c53c1de7ffc3db4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:06:02 +0800 Subject: [PATCH 0225/1089] more --- python/sglang/srt/managers/expert_location.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c533784f2cd..9f571fc5d1e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import torch - from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -23,7 +22,7 @@ def from_model(model): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -41,6 +40,9 @@ def init_new(num_layers: int, num_logical_experts: int): def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) + def local_physical_to_global_physical(self): + return TODO + def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) From 3c6d511c8e0fd0c62a5675ddfe328ddf2a943342 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:06:31 +0800 Subject: [PATCH 0226/1089] more --- python/sglang/srt/managers/expert_location.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9f571fc5d1e..9edd4e242b8 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -22,7 +22,7 @@ def from_model(model): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -40,8 +40,8 @@ def init_new(num_layers: int, num_logical_experts: int): def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) - def local_physical_to_global_physical(self): - return TODO + def local_physical_to_global_physical(self, rank: int, local_physical_expert_index: int): + return self.num_local_physical_experts * rank + local_physical_expert_index def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From 5036994d75513a68c1ff42a02fae1bfc3a4a6418 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:07:07 +0800 Subject: [PATCH 0227/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 3356228eea4..b0712aad0f8 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -191,8 +191,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): 0 ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: - for expert_idx in token_record: - num_recv_tokens_per_expert_list[expert_idx] += 1 + for global_physical_expert_idx in token_record: + local_physical_expert_idx = TODO + num_recv_tokens_per_expert_list[local_physical_expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 
15cdeea585411293f0ab6647493dd0d811294cb9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:07:32 +0800 Subject: [PATCH 0228/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index b0712aad0f8..b78da1003ad 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -192,7 +192,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: - local_physical_expert_idx = TODO + local_physical_expert_idx = self._expert_location_metadata.global_physical_to_local_physical(global_physical_expert_idx) num_recv_tokens_per_expert_list[local_physical_expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) From 8fe42ea8c8e1dec05ae3dd3d795c54eecaec8345 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:08:01 +0800 Subject: [PATCH 0229/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9edd4e242b8..6dee85cf58d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -40,8 +40,8 @@ def init_new(num_layers: int, num_logical_experts: int): def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) - def local_physical_to_global_physical(self, rank: int, local_physical_expert_index: int): - return self.num_local_physical_experts * rank + local_physical_expert_index + def global_physical_to_local_physical(self, global_physical_expert_index: int): + return global_physical_expert_index % self.num_local_physical_experts def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From 0d17f33101f1dcbfcfeb3b692b3a40c72dcd8440 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:08:18 +0800 Subject: [PATCH 0230/1089] more --- python/sglang/srt/managers/expert_location.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6dee85cf58d..ebae0148d94 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -40,6 +40,9 @@ def init_new(num_layers: int, num_logical_experts: int): def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) + def local_physical_to_global_physical(self, rank: int, local_physical_expert_index: int): + return self.num_local_physical_experts * rank + local_physical_expert_index + def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts From 2e0ca4fbff55246806394b6e6a860edfedcf77a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:09:03 +0800 Subject: [PATCH 0231/1089] more --- python/sglang/srt/managers/expert_distribution.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index b78da1003ad..5f5db57e6a3 100644 
--- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -5,7 +5,6 @@ from typing import Any, List, Optional, Type import torch - from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -188,11 +187,12 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): torch.cuda.synchronize() num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: - local_physical_expert_idx = self._expert_location_metadata.global_physical_to_local_physical(global_physical_expert_idx) + local_physical_expert_idx = self._expert_location_metadata.global_physical_to_local_physical( + global_physical_expert_idx) num_recv_tokens_per_expert_list[local_physical_expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) @@ -329,10 +329,9 @@ def postprocess_dumps( for local_physical_expert_index in range( expert_location_metadata.num_local_physical_experts ): - global_physical_expert_index = ( - expert_location_metadata.num_local_physical_experts - * physical_dump["rank"] - + local_physical_expert_index + global_physical_expert_index = expert_location_metadata.local_physical_to_global_physical( + rank=physical_dump["rank"], + local_physical_expert_index=local_physical_expert_index ) logical_expert_index = ( expert_location_metadata.physical_to_logical_map[ From fe900082218499dda9e29ace10de195206e15d37 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:09:49 +0800 Subject: [PATCH 0232/1089] more --- python/sglang/srt/entrypoints/engine.py | 3 --- python/sglang/srt/managers/tokenizer_manager.py | 7 +++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a5e9f6bd43a..33aab232fe3 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -596,7 +596,4 @@ def _launch_subprocesses( # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - tokenizer_manager.expert_location_metadata = scheduler_info[ - "expert_location_metadata" - ] return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7266a3e26b9..38258900247 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -230,7 +229,7 @@ def __init__( # Set after scheduler is initialized self.max_req_input_len = None - self.expert_location_metadata = None + self.expert_location_metadata = TODO # Metrics if self.enable_metrics: @@ -969,8 +968,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = 
len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 2d533bbe39d8f5ed2e3da90ebaf11fa924a9b6ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:10:15 +0800 Subject: [PATCH 0233/1089] more --- python/sglang/srt/managers/data_parallel_controller.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 6eef426c640..fb0264a6ea9 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -100,7 +100,6 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: ) self.max_req_input_len = None - self.expert_location_metadata = None def launch_dp_schedulers(self, server_args, port_args): base_gpu_id = 0 @@ -220,7 +219,6 @@ def launch_tensor_parallel_group( self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] self.max_req_input_len = scheduler_info[0]["max_req_input_len"] - self.expert_location_metadata = scheduler_info[0]["expert_location_metadata"] def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) @@ -267,7 +265,6 @@ def run_data_parallel_controller_process( "status": "ready", "max_total_num_tokens": controller.max_total_num_tokens, "max_req_input_len": controller.max_req_input_len, - "expert_location_metadata": controller.expert_location_metadata, } ) if server_args.node_rank == 0: From b9f368f9912768ce857eb68faedb671bb1b8d65d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:10:49 +0800 Subject: [PATCH 0234/1089] more --- python/sglang/srt/managers/scheduler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 519fa0a3c3f..4554b4deac8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -287,7 +287,6 @@ def __init__( self.random_seed, self.device, worker_global_server_args_dict, - self.expert_location_metadata, _, _, _, @@ -2052,7 +2051,6 @@ def run_scheduler_process( "status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens, "max_req_input_len": scheduler.max_req_input_len, - "expert_location_metadata": scheduler.expert_location_metadata, } ) disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode From 12a0270215789a7995252fe93c29a46f7aa53766 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:11:38 +0800 Subject: [PATCH 0235/1089] fmt --- .../srt/managers/expert_distribution.py | 20 ++++++++++++------- python/sglang/srt/managers/expert_location.py | 5 ++++- .../sglang/srt/managers/tokenizer_manager.py | 5 +++-- python/sglang/srt/managers/tp_worker.py | 1 - .../sglang/srt/model_executor/model_runner.py | 2 +- test/srt/test_expert_distribution.py | 5 ++++- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 5f5db57e6a3..28b71f059de 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -5,6 +5,7 @@ from typing import Any, List, Optional, Type import torch + from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -187,12 +188,15 @@ def on_select_experts(self, layer_idx: int, topk_ids: 
torch.Tensor): torch.cuda.synchronize() num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: - local_physical_expert_idx = self._expert_location_metadata.global_physical_to_local_physical( - global_physical_expert_idx) + local_physical_expert_idx = ( + self._expert_location_metadata.global_physical_to_local_physical( + global_physical_expert_idx + ) + ) num_recv_tokens_per_expert_list[local_physical_expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) @@ -329,9 +333,11 @@ def postprocess_dumps( for local_physical_expert_index in range( expert_location_metadata.num_local_physical_experts ): - global_physical_expert_index = expert_location_metadata.local_physical_to_global_physical( - rank=physical_dump["rank"], - local_physical_expert_index=local_physical_expert_index + global_physical_expert_index = ( + expert_location_metadata.local_physical_to_global_physical( + rank=physical_dump["rank"], + local_physical_expert_index=local_physical_expert_index, + ) ) logical_expert_index = ( expert_location_metadata.physical_to_logical_map[ diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index ebae0148d94..674eb7434a3 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import torch + from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -40,7 +41,9 @@ def init_new(num_layers: int, num_logical_experts: int): def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) - def local_physical_to_global_physical(self, rank: int, local_physical_expert_index: int): + def local_physical_to_global_physical( + self, rank: int, local_physical_expert_index: int + ): return self.num_local_physical_experts * rank + local_physical_expert_index def global_physical_to_local_physical(self, global_physical_expert_index: int): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 38258900247..a5d159abf11 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -968,8 +969,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index f53941d056d..174f2e53321 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -145,7 +145,6 @@ def get_worker_info(self): self.random_seed, self.device, global_server_args_dict, - self.model_runner.expert_location_metadata, self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.max_context_len, self.model_runner.token_to_kv_pool.size, diff 
--git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c9288e18b04..201c1c77d5e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -191,7 +191,7 @@ def initialize(self, min_per_gpu_memory: float): self.sampler = Sampler() self.load_model() - self.expert_location_metadata = ExpertLocationMetadata.from_model(self.model) + self.expert_location_metadata = TODO expert_distribution_recorder.initialize( server_args, self.expert_location_metadata, diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index ed65998d2cf..d6caca23bf5 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,6 +3,7 @@ import requests import torch + from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -25,7 +26,9 @@ def test_expert_distribution_record(self): with self.subTest(info=info): self._execute_core(**info) - def _execute_core(self, model_path: str, mode_detail: bool = False, tp_size: int = 1): + def _execute_core( + self, model_path: str, mode_detail: bool = False, tp_size: int = 1 + ): """Test expert distribution record endpoints""" os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = ( "1" if mode_detail else "0" From dc09c51f112751ca8bf2ef1f8e5da82d3bf80c0c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:13:29 +0800 Subject: [PATCH 0236/1089] more --- python/sglang/srt/configs/model_config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index d17add76920..9a644703722 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -24,6 +24,7 @@ from sglang.srt.hf_transformers_utils import get_config, get_context_length from sglang.srt.layers.quantization import QUANTIZATION_METHODS +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var, is_hip logger = logging.getLogger(__name__) @@ -171,6 +172,19 @@ def __init__( self.hf_eos_token_id = self.get_hf_eos_token_id() self.image_token_id = getattr(self.hf_config, "image_token_id", None) + @staticmethod + def from_server_args(server_args: ServerArgs): + return ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, + ) + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" From 46d4e3e2c61496c682e959a9f4076aa5f4aeb497 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:13:56 +0800 Subject: [PATCH 0237/1089] more --- python/sglang/bench_one_batch.py | 11 +---------- python/sglang/srt/managers/scheduler.py | 11 +---------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index a1b2d4723b8..a34cd9af111 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -129,16 +129,7 @@ def load_model(server_args, port_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank 
== 0 else lambda *args, **kwargs: None - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + model_config = ModelConfig.from_server_args(server_args) model_runner = ModelRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 383cd680945..15ac5e24230 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -430,16 +430,7 @@ def __init__( def init_tokenizer(self): server_args = self.server_args - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + self.model_config = ModelConfig.from_server_args(server_args) self.is_generation = self.model_config.is_generation if server_args.skip_tokenizer_init: From 0e579a369e41862be9bd2af866eca5485663878f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:14:41 +0800 Subject: [PATCH 0238/1089] more --- python/sglang/srt/managers/scheduler.py | 25 +++++++++---------- .../sglang/srt/managers/tokenizer_manager.py | 16 +++--------- test/srt/test_gptqmodel_dynamic.py | 11 +------- 3 files changed, 16 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 15ac5e24230..8e139a67677 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -131,6 +129,7 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier expert_distribution_recorder = ExpertDistributionRecorder() @@ -371,8 +370,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1251,10 +1250,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1279,9 +1278,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) 
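[Editor's note] Patches 0236-0237 above replace three copies of an eight-argument ModelConfig(...) call with a single ModelConfig.from_server_args(server_args) factory. A sketch of the pattern, using stand-in dataclasses rather than the real ServerArgs/ModelConfig:

from dataclasses import dataclass

@dataclass
class ServerArgsSketch:
    model_path: str
    trust_remote_code: bool = False
    dtype: str = "auto"

@dataclass
class ModelConfigSketch:
    model_path: str
    trust_remote_code: bool
    dtype: str

    @staticmethod
    def from_server_args(server_args: ServerArgsSketch) -> "ModelConfigSketch":
        # Single place that knows how a ModelConfig is derived from server args,
        # instead of every call site duplicating the keyword list.
        return ModelConfigSketch(
            model_path=server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            dtype=server_args.dtype,
        )

cfg = ModelConfigSketch.from_server_args(ServerArgsSketch(model_path="my/model"))
assert cfg.dtype == "auto"
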
else: self.running_batch.batch_is_full = True break @@ -1513,8 +1512,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 33afffbd6de..9de7a4eec18 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -156,16 +155,7 @@ def __init__( # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + self.model_config = ModelConfig.from_server_args(server_args) self.is_generation = self.model_config.is_generation self.is_image_gen = self.model_config.is_image_gen @@ -959,8 +949,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index 54dbaf49663..27ccd9a4b2f 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -43,16 +43,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): pass server_args = ServerArgs(model_path=model_path, dtype=torch.float16) - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) + model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig() device_config = DeviceConfig("cuda") From 2fd3c19bb8f83dd43291f88db6e8e1ad24d91491 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:15:37 +0800 Subject: [PATCH 0239/1089] more --- python/sglang/srt/configs/model_config.py | 4 ++-- python/sglang/srt/managers/tp_worker.py | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 9a644703722..ac26cb171fd 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -173,9 +173,9 @@ def __init__( self.image_token_id = getattr(self.hf_config, "image_token_id", None) @staticmethod - def from_server_args(server_args: ServerArgs): + def from_server_args(server_args: ServerArgs, model_path: 
str=None): return ModelConfig( - server_args.model_path, + model_path=model_path or server_args.model_path, trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, context_length=server_args.context_length, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 174f2e53321..c8388da75ed 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -57,19 +57,13 @@ def __init__( self.tp_rank = tp_rank # Init model and tokenizer - self.model_config = ModelConfig( - ( + self.model_config = ModelConfig.from_server_args( + server_args, + model_path=( server_args.model_path if not is_draft_worker else server_args.speculative_draft_model_path ), - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, ) self.model_runner = ModelRunner( model_config=self.model_config, From d053592f36bdb142757f91f273aeb93bc62b3fc6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:16:04 +0800 Subject: [PATCH 0240/1089] fmt --- python/sglang/srt/configs/model_config.py | 2 +- python/sglang/srt/managers/scheduler.py | 25 ++++++++++--------- .../sglang/srt/managers/tokenizer_manager.py | 5 ++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ac26cb171fd..6ec19fd454c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -173,7 +173,7 @@ def __init__( self.image_token_id = getattr(self.hf_config, "image_token_id", None) @staticmethod - def from_server_args(server_args: ServerArgs, model_path: str=None): + def from_server_args(server_args: ServerArgs, model_path: str = None): return ModelConfig( model_path=model_path or server_args.model_path, trust_remote_code=server_args.trust_remote_code, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8e139a67677..15ac5e24230 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,6 +32,8 @@ import setproctitle import torch import zmq +from torch.distributed import barrier + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -129,7 +131,6 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -from torch.distributed import barrier expert_distribution_recorder = ExpertDistributionRecorder() @@ -370,8 +371,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1250,10 +1251,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): 
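[Editor's note] The tokenizer-manager hunks reflowed repeatedly above implement incremental streaming: on each BatchTokenIDOut, only the tokens past last_output_offset are emitted, then the offset advances. A self-contained sketch of that bookkeeping, where ReqStateSketch stands in for the per-request state object:

class ReqStateSketch:
    def __init__(self):
        self.last_output_offset = 0

def next_chunk(state: ReqStateSketch, output_ids: list) -> list:
    # Send only the tokens produced since the previous flush, then advance.
    chunk = output_ids[state.last_output_offset:]
    state.last_output_offset = len(output_ids)
    return chunk

state = ReqStateSketch()
assert next_chunk(state, [11, 12]) == [11, 12]
assert next_chunk(state, [11, 12, 13]) == [13]  # only the new token goes out
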
self.running_batch.batch_is_full = True @@ -1278,9 +1279,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1512,8 +1513,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9de7a4eec18..dbb5b431006 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -949,8 +950,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 8fc36d8e8231b8721f0d40d5098a654e1ede019a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:18:27 +0800 Subject: [PATCH 0241/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 +++- python/sglang/srt/managers/tokenizer_manager.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 33aab232fe3..d77f1168286 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -495,6 +495,8 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) + expert_location_metadata = TODO + scheduler_procs = [] if server_args.dp_size == 1: # Launch tensor parallel scheduler processes @@ -565,7 +567,7 @@ def _launch_subprocesses( detoken_proc.start() # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) + tokenizer_manager = TokenizerManager(server_args, port_args, expert_location_metadata) if server_args.chat_template: load_chat_template_for_openai_api( tokenizer_manager, server_args.chat_template, server_args.model_path diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a6a400e1187..93a9507bf19 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -52,6 +52,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers import expert_distribution +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -138,6 +139,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, ): # Parse args self.server_args = server_args @@ -221,7 +223,7 @@ 
def __init__( # Set after scheduler is initialized self.max_req_input_len = None - self.expert_location_metadata = TODO + self.expert_location_metadata = expert_location_metadata # Metrics if self.enable_metrics: From c4ebfae5f2716623440083c77809d6e4aa6673e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:18:46 +0800 Subject: [PATCH 0242/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/tokenizer_manager.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index d77f1168286..ca53d3d70f2 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -495,7 +495,7 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - expert_location_metadata = TODO + expert_location_metadata = ExpertLocationMetadata.from_model(TODO) scheduler_procs = [] if server_args.dp_size == 1: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 93a9507bf19..5274315a216 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -962,8 +961,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From a787fbd9b1769e178bca541861edb3be8322bb6c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:19:09 +0800 Subject: [PATCH 0243/1089] more --- python/sglang/srt/entrypoints/engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index ca53d3d70f2..857b283426e 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -30,6 +30,8 @@ import zmq import zmq.asyncio from PIL.Image import Image +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers.expert_location import ExpertLocationMetadata # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -495,7 +497,8 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - expert_location_metadata = ExpertLocationMetadata.from_model(TODO) + model_config = ModelConfig.from_server_args(server_args) + expert_location_metadata = ExpertLocationMetadata.from_model_config(model_config) scheduler_procs = [] if server_args.dp_size == 1: From 5e825401477ce24b7c709494f4ace42f95efbccf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:19:34 +0800 Subject: [PATCH 0244/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 674eb7434a3..e1a859fce94 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,6 +2,7 @@ import torch 
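[Editor's note] Patches 0241-0243 above invert the ownership of the expert-location metadata: _launch_subprocesses derives it from the model config up front and injects it into the TokenizerManager constructor, rather than fishing it out of scheduler_info afterwards. A sketch of that injection with illustrative stand-in names:

from dataclasses import dataclass

@dataclass
class ExpertLocationMetadataSketch:
    num_layers: int

def build_metadata(model_config: dict) -> ExpertLocationMetadataSketch:
    # Built once in the parent process, before any subprocess is launched.
    return ExpertLocationMetadataSketch(num_layers=model_config.get("num_hidden_layers", 1))

class TokenizerManagerSketch:
    def __init__(self, server_args, port_args, expert_location_metadata):
        # Set at construction time, no longer "after scheduler is initialized".
        self.expert_location_metadata = expert_location_metadata

tm = TokenizerManagerSketch(None, None, build_metadata({"num_hidden_layers": 61}))
assert tm.expert_location_metadata.num_layers == 61
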
+from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -14,9 +15,9 @@ class ExpertLocationMetadata: physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) @staticmethod - def from_model(model): - if hasattr(model, "get_expert_location_metadata"): - return model.get_expert_location_metadata() + def from_model_config(model_config: ModelConfig): + if hasattr(model_config, "get_expert_location_metadata"): + return model_config.get_expert_location_metadata() return ExpertLocationMetadata._init_dummy() @staticmethod From e4373eecb3ffa71d50e92eec8def94316ce9881e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:20:16 +0800 Subject: [PATCH 0245/1089] more --- python/sglang/srt/managers/expert_location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e1a859fce94..460879ffc19 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -4,6 +4,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.model_loader import get_model_architecture @dataclass @@ -16,6 +17,7 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): + model_class, _ = get_model_architecture(model_config) if hasattr(model_config, "get_expert_location_metadata"): return model_config.get_expert_location_metadata() return ExpertLocationMetadata._init_dummy() From 0adb6df75f3b36a30f23862d160efca685d36296 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:20:31 +0800 Subject: [PATCH 0246/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 460879ffc19..8704abd64a1 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -18,8 +18,8 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) - if hasattr(model_config, "get_expert_location_metadata"): - return model_config.get_expert_location_metadata() + if hasattr(model_class, "get_expert_location_metadata"): + return model_class.get_expert_location_metadata() return ExpertLocationMetadata._init_dummy() @staticmethod From 5cb974a7a75f4ff7d3078148890546c0a414e8e2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:21:46 +0800 Subject: [PATCH 0247/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8704abd64a1..7d436802f2d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -19,7 +19,7 @@ class ExpertLocationMetadata: def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): - return model_class.get_expert_location_metadata() + return model_class.get_expert_location_metadata(model_config.hf_config) return ExpertLocationMetadata._init_dummy() @staticmethod From 4571aa1d5af21f867413559dba28ea97de8d2a97 Mon Sep 17 
00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:22:08 +0800 Subject: [PATCH 0248/1089] more --- python/sglang/srt/models/deepseek_v2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0bb289427c7..5dcdff1499a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1608,10 +1608,11 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() - def get_expert_location_metadata(self): + @classmethod + def get_expert_location_metadata(cls, config): return ExpertLocationMetadata.init_new( - num_layers=self.config.num_hidden_layers, - num_logical_experts=self.config.n_routed_experts, + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, ) From f61cfc8f1dacc0024c7a1b1c8d090cf7e2c89cdd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:22:32 +0800 Subject: [PATCH 0249/1089] more --- python/sglang/srt/models/deepseek_v2.py | 62 ++++++++++++------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5dcdff1499a..0dafbe5ed08 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,10 +22,6 @@ import torch import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -74,6 +70,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -90,7 +89,6 @@ decode_attention_fwd_grouped_rope, ) - logger = logging.getLogger(__name__) @@ -412,7 +410,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -525,12 +523,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -542,8 +540,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -584,7 +582,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 
0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -771,16 +769,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -833,11 +831,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -913,15 +911,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
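[Editor's note] Patches 0244-0248 above converge on a class-level hook: the metadata is derived from the resolved model class and its hf config, before any weights are loaded, with a 1-layer/1-expert dummy as fallback. A sketch of that dispatch under stand-in names (the real code resolves the class via get_model_architecture):

class DenseModelSketch:  # no expert hook -> falls back to the dummy
    pass

class MoEModelSketch:
    @classmethod
    def get_expert_location_metadata(cls, config: dict):
        # Derived from the hf config alone, so no instantiated model is needed.
        return (config["num_hidden_layers"], config["n_routed_experts"])

def from_model_config_sketch(model_class, hf_config: dict):
    if hasattr(model_class, "get_expert_location_metadata"):
        return model_class.get_expert_location_metadata(hf_config)
    return (1, 1)

assert from_model_config_sketch(DenseModelSketch, {}) == (1, 1)
assert from_model_config_sketch(
    MoEModelSketch, {"num_hidden_layers": 61, "n_routed_experts": 256}
) == (61, 256)
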
@@ -978,7 +976,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1496,7 +1494,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1527,11 +1525,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From 9ec9519b2ebc4dd3ae8bc6fe67db6e6f1a3f82fa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:22:47 +0800 Subject: [PATCH 0250/1089] more --- python/sglang/srt/models/qwen2_moe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 9c8ffce9a7e..d724d1afa65 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -487,10 +487,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - def get_expert_location_metadata(self): + @classmethod + def get_expert_location_metadata(cls, config): return ExpertLocationMetadata.init_new( - num_layers=self.config.num_hidden_layers, - num_logical_experts=self.config.num_experts, + num_layers=config.num_hidden_layers, + num_logical_experts=config.num_experts, ) From 114c1387243b3a33a7d41aab4490a513f5fe9eb2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:29:11 +0800 Subject: [PATCH 0251/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index dfecb63d940..05498db77c7 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -409,6 +409,16 @@ def weight_loader( weight_name: str, shard_id: str, expert_id: int, + ) -> None: + TODO + + def _weight_loader_physical( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: if expert_id < self.start_expert_id or expert_id > self.end_expert_id: return From b14f1513e13aebd51ad26a8ec2b986b7044e3b1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:29:50 +0800 Subject: [PATCH 0252/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 05498db77c7..be301b48964 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -258,7 +258,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + 
seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -410,7 +410,9 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - TODO + physical_expert_ids = TODO + for physical_expert_ids in physical_expert_ids: + TODO def _weight_loader_physical( self, @@ -445,7 +447,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -477,11 +479,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight From e267065da1519db664ab662a8762db4d9ca40629 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:30:30 +0800 Subject: [PATCH 0253/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index be301b48964..b29004fb92a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -411,8 +411,11 @@ def weight_loader( expert_id: int, ) -> None: physical_expert_ids = TODO - for physical_expert_ids in physical_expert_ids: - TODO + for physical_expert_id in physical_expert_ids: + self._weight_loader_physical( + param=param, loaded_weight=loaded_weight, weight_name=weight_name, shard_id=shard_id, + expert_id=physical_expert_id + ) def _weight_loader_physical( self, From 4033e929655bb9dc0314f2b68f827cb4c6cc61f2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:31:11 +0800 Subject: [PATCH 0254/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 +- python/sglang/srt/managers/expert_location.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index b29004fb92a..5b8b5df943a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -410,7 +410,7 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - physical_expert_ids = TODO + physical_expert_ids = expert_location_metadata.logical_to_global_physical(expert_id) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( param=param, loaded_weight=loaded_weight, weight_name=weight_name, shard_id=shard_id, diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7d436802f2d..6d220a82963 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -26,7 +26,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - 
num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -45,13 +45,16 @@ def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts + def logical_to_global_physical(self, logical_expert_id: int): + return TODO + def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) From 384cfd4860ff38c03735588c093942b69a95b480 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:31:24 +0800 Subject: [PATCH 0255/1089] more --- python/sglang/srt/managers/expert_location.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6d220a82963..fc1b1ebeb34 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -53,7 +52,7 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_global_physical(self, logical_expert_id: int): - return TODO + return logical_expert_id # TODO support arbitrary mapping def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From 811461819538d9cb03ddda61840bdbc8e3527c70 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:31:53 +0800 Subject: [PATCH 0256/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index fc1b1ebeb34..0df497d8bbb 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -52,7 +52,7 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_global_physical(self, logical_expert_id: int): - return logical_expert_id # TODO support arbitrary mapping + return logical_expert_id # TODO add a logical_to_physical_map def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From fc11b7a5951bcf0e810765ab18c6ecdf22a655ce Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:34:34 +0800 Subject: [PATCH 0257/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 5 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 19 +++--- python/sglang/srt/managers/expert_location.py | 5 +- .../sglang/srt/managers/tokenizer_manager.py | 5 +- python/sglang/srt/models/deepseek_v2.py | 61 ++++++++++--------- 5 files changed, 53 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py 
b/python/sglang/srt/entrypoints/engine.py index 857b283426e..f149a276fe0 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -30,6 +30,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -570,7 +571,9 @@ def _launch_subprocesses( detoken_proc.start() # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args, expert_location_metadata) + tokenizer_manager = TokenizerManager( + server_args, port_args, expert_location_metadata + ) if server_args.chat_template: load_chat_template_for_openai_api( tokenizer_manager, server_args.chat_template, server_args.model_path diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 5b8b5df943a..cffc15a4269 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -258,7 +258,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -410,11 +410,16 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - physical_expert_ids = expert_location_metadata.logical_to_global_physical(expert_id) + physical_expert_ids = expert_location_metadata.logical_to_global_physical( + expert_id + ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( - param=param, loaded_weight=loaded_weight, weight_name=weight_name, shard_id=shard_id, - expert_id=physical_expert_id + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=physical_expert_id, ) def _weight_loader_physical( @@ -450,7 +455,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -482,11 +487,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 0df497d8bbb..60ae8b4fe2b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -25,7 +26,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like 
duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -44,7 +45,7 @@ def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5274315a216..93a9507bf19 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -961,8 +962,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0dafbe5ed08..2380031fc73 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,6 +22,10 @@ import torch import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -70,9 +74,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -410,7 +411,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -523,12 +524,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -540,8 +541,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, 
self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -582,7 +583,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -769,16 +770,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -831,11 +832,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -911,15 +912,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
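[Editor's note] Patches 0251-0256 above split EPMoE.weight_loader in two: a checkpoint ships one weight per *logical* expert, and the loader fans it out to every *physical* replica of that expert via logical_to_global_physical. A sketch of the fan-out; the identity mapping mirrors the current placeholder (a real logical-to-physical map is still a TODO in the series):

import torch

def logical_to_global_physical(logical_expert_id: int) -> list:
    return [logical_expert_id]  # placeholder, one replica per logical expert for now

def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor, logical_expert_id: int):
    for physical_expert_id in logical_to_global_physical(logical_expert_id):
        param[physical_expert_id].copy_(loaded_weight)  # one copy per physical replica

param = torch.zeros(4, 2)  # 4 physical expert slots, 2 weights each
weight_loader(param, torch.ones(2), logical_expert_id=1)
assert param[1].sum() == 2
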
@@ -976,7 +977,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1494,7 +1495,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1525,11 +1526,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From df0611f75e8a2fa9ec87ffa731c9f719a2c70709 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:36:31 +0800 Subject: [PATCH 0258/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++-- python/sglang/srt/managers/scheduler.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 857b283426e..4872b407d1f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -521,7 +521,7 @@ def _launch_subprocesses( ) proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), + args=(server_args, port_args,expert_location_metadata, gpu_id, tp_rank, None, writer), ) with memory_saver_adapter.configure_subprocess(): proc.start() @@ -533,7 +533,7 @@ def _launch_subprocesses( scheduler_pipe_readers = [reader] proc = mp.Process( target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), + args=(server_args, port_args,expert_location_metadata, writer), ) proc.start() scheduler_procs.append(proc) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1d533791fc3..e4d05b4cf13 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -54,6 +54,7 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -2005,6 +2006,7 @@ def _import_static_state(model, static_params): def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, gpu_id: int, tp_rank: int, dp_rank: Optional[int], From 204351de328645aa2088fa405abdd19a013de571 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:37:19 +0800 Subject: [PATCH 0259/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 +-- .../srt/managers/data_parallel_controller.py | 24 ++++++++++-------- python/sglang/srt/managers/scheduler.py | 25 +++++++++---------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 
4872b407d1f..43343864c11 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -521,7 +521,7 @@ def _launch_subprocesses( ) proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args,expert_location_metadata, gpu_id, tp_rank, None, writer), + args=(server_args, port_args, expert_location_metadata, gpu_id, tp_rank, None, writer), ) with memory_saver_adapter.configure_subprocess(): proc.start() @@ -533,7 +533,7 @@ def _launch_subprocesses( scheduler_pipe_readers = [reader] proc = mp.Process( target=run_data_parallel_controller_process, - args=(server_args, port_args,expert_location_metadata, writer), + args=(server_args, port_args, expert_location_metadata, writer), ) proc.start() scheduler_procs.append(proc) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index fb0264a6ea9..5b2421aab3d 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -22,8 +22,8 @@ import psutil import setproctitle import zmq - from sglang.srt.layers.dp_attention import compute_dp_attention_world_info +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -54,7 +54,8 @@ def from_str(cls, method: str): class DataParallelController: """A controller that dispatches requests to multiple data parallel workers.""" - def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: + def __init__(self, server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata) -> None: # Parse args self.max_total_num_tokens = None self.server_args = server_args @@ -83,10 +84,10 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: self.workers = [None] * server_args.dp_size if server_args.enable_dp_attention: - dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args) + dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args, expert_location_metadata) self.control_message_step = server_args.tp_size else: - dp_port_args = self.launch_dp_schedulers(server_args, port_args) + dp_port_args = self.launch_dp_schedulers(server_args, port_args, expert_location_metadata) self.control_message_step = 1 # Only node rank 0 runs the real data parallel controller that dispatches the requests. 
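(This commit is mechanical parameter threading: expert_location_metadata is built once at launch and handed through every process and thread boundary until it reaches each scheduler. A minimal sketch of the same pattern; `Meta` and `run_worker` are hypothetical stand-ins, not the real classes:

    import multiprocessing as mp
    from dataclasses import dataclass

    @dataclass
    class Meta:  # stand-in for ExpertLocationMetadata
        num_layers: int

    def run_worker(meta: Meta, rank: int) -> None:
        # The object is pickled into each child process, so every worker
        # gets its own copy; mutations are not visible across processes.
        print(f"rank={rank} sees num_layers={meta.num_layers}")

    if __name__ == "__main__":
        meta = Meta(num_layers=61)
        procs = [mp.Process(target=run_worker, args=(meta, r)) for r in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

Because the argument is pickled, each scheduler process ends up with an independent copy of the metadata rather than shared state; the clone()/update() plumbing added later in this series operates on these per-process copies.)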
@@ -101,7 +102,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: self.max_req_input_len = None - def launch_dp_schedulers(self, server_args, port_args): + def launch_dp_schedulers(self, server_args, port_args, expert_location_metadata): base_gpu_id = 0 threads = [] @@ -124,7 +125,7 @@ def launch_dp_schedulers(self, server_args, port_args): # Create a thread for each worker thread = threading.Thread( target=self.launch_tensor_parallel_group_thread, - args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event), + args=(server_args, tmp_port_args, expert_location_metadata, base_gpu_id, dp_rank, ready_event), ) threads.append(thread) base_gpu_id += server_args.tp_size * server_args.gpu_id_step @@ -145,11 +146,12 @@ def launch_tensor_parallel_group_thread( self, server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, base_gpu_id: int, dp_rank: int, ready_event: threading.Event, ): - self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank) + self.launch_tensor_parallel_group(server_args, port_args, expert_location_metadata, base_gpu_id, dp_rank) ready_event.set() # This thread cannot be closed because otherwise the `kill_itself_when_parent_died` @@ -157,8 +159,8 @@ def launch_tensor_parallel_group_thread( while True: pass - def launch_dp_attention_schedulers(self, server_args, port_args): - self.launch_tensor_parallel_group(server_args, port_args, 0, None) + def launch_dp_attention_schedulers(self, server_args, port_args, expert_location_metadata): + self.launch_tensor_parallel_group(server_args, port_args, expert_location_metadata, 0, None) dp_port_args = [] for dp_rank in range(server_args.dp_size): dp_port_args.append(PortArgs.init_new(server_args, dp_rank)) @@ -168,6 +170,7 @@ def launch_tensor_parallel_group( self, server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, base_gpu_id: int, dp_rank: int, ): @@ -252,6 +255,7 @@ def event_loop(self): def run_data_parallel_controller_process( server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, pipe_writer, ): setproctitle.setproctitle("sglang::data_parallel_controller") @@ -259,7 +263,7 @@ def run_data_parallel_controller_process( parent_process = psutil.Process().parent() try: - controller = DataParallelController(server_args, port_args) + controller = DataParallelController(server_args, port_args, expert_location_metadata) pipe_writer.send( { "status": "ready", diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e4d05b4cf13..84e11eea8f0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -132,6 +130,7 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -370,8 +369,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps 
self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1250,10 +1249,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1278,9 +1277,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1512,8 +1511,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) From 78af10cd1020207ba4127c532d6d9ba205bf15ee Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:37:38 +0800 Subject: [PATCH 0260/1089] more --- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 5b2421aab3d..0223a8067eb 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -209,7 +209,7 @@ def launch_tensor_parallel_group( ) proc = mp.Process( target=run_scheduler_process, - args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer), + args=(server_args, rank_port_args, expert_location_metadata, gpu_id, tp_rank, dp_rank, writer), ) proc.start() self.scheduler_procs.append(proc) From 5df2e41ed159f51dcddd93991066d505838171e3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:38:45 +0800 Subject: [PATCH 0261/1089] more --- python/sglang/srt/managers/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 84e11eea8f0..e4ae286200f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -165,6 +165,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, gpu_id: int, tp_rank: int, dp_rank: Optional[int], @@ -256,6 +257,7 @@ def __init__( self.tp_worker = TpWorkerClass( server_args=server_args, + expert_location_metadata=expert_location_metadata, gpu_id=gpu_id, tp_rank=tp_rank, dp_rank=dp_rank, @@ -2037,7 +2039,7 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: - scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler(server_args, port_args, expert_location_metadata, gpu_id, tp_rank, dp_rank) pipe_writer.send( { "status": "ready", From c302cf39e97a9dab337076f4c2fe290b56fe03a4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:39:20 +0800 Subject: [PATCH 0262/1089] more --- python/sglang/srt/managers/tp_worker.py | 3 +++ python/sglang/srt/managers/tp_worker_overlap_thread.py | 4 +++- python/sglang/srt/model_executor/model_runner.py | 4 +++- 3 files changed, 9 insertions(+), 2 
deletions(-) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c8388da75ed..984a7716d37 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -22,6 +22,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -45,6 +46,7 @@ class TpModelWorker: def __init__( self, server_args: ServerArgs, + expert_location_metadata: ExpertLocationMetadata, gpu_id: int, tp_rank: int, dp_rank: Optional[int], @@ -67,6 +69,7 @@ def __init__( ) self.model_runner = ModelRunner( model_config=self.model_config, + expert_location_metadata=expert_location_metadata, mem_fraction_static=server_args.mem_fraction_static, gpu_id=gpu_id, tp_rank=tp_rank, diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index fb4fdc6d55f..3579c986551 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -23,6 +23,7 @@ import psutil import torch +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -54,13 +55,14 @@ class TpModelWorkerClient: def __init__( self, server_args: ServerArgs, + expert_location_metadata: ExpertLocationMetadata, gpu_id: int, tp_rank: int, dp_rank: Optional[int], nccl_port: int, ): # Load the model - self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port) + self.worker = TpModelWorker(server_args,expert_location_metadata, gpu_id, tp_rank, dp_rank, nccl_port) self.max_running_requests = self.worker.max_running_requests self.device = self.worker.device self.gpu_id = gpu_id diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 201c1c77d5e..4227eec69ab 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -97,6 +97,7 @@ class ModelRunner: def __init__( self, model_config: ModelConfig, + expert_location_metadata: ExpertLocationMetadata, mem_fraction_static: float, gpu_id: int, tp_rank: int, @@ -178,6 +179,8 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.expert_location_metadata = expert_location_metadata + # If it is a draft model tp_group can be different. 
self.initialize(min_per_gpu_memory) @@ -191,7 +194,6 @@ def initialize(self, min_per_gpu_memory: float): self.sampler = Sampler() self.load_model() - self.expert_location_metadata = TODO expert_distribution_recorder.initialize( server_args, self.expert_location_metadata, From bc4ace9660a798d5ce25b6b9cc00ff8cd730b51a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:41:52 +0800 Subject: [PATCH 0263/1089] more --- python/sglang/srt/managers/expert_location.py | 7 +++++++ python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7d436802f2d..c7e3a61288d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -52,6 +52,13 @@ def local_physical_to_global_physical( def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts + def clone(self): + return ExpertLocationMetadata( + num_layers=self.num_layers, + num_local_physical_experts=self.num_local_physical_experts, + num_logical_experts=self.num_logical_experts, + physical_to_logical_map=self.physical_to_logical_map.clone(), + ) def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4227eec69ab..36b691eb4ed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -179,7 +179,7 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() - self.expert_location_metadata = expert_location_metadata + self.expert_location_metadata = expert_location_metadata.clone() # If it is a draft model tp_group can be different. 
self.initialize(min_per_gpu_memory) From a9f5f54bed2fffdbf900bfca34853ca849007bb1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:42:42 +0800 Subject: [PATCH 0264/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 15 ++++- .../srt/managers/data_parallel_controller.py | 52 +++++++++++++--- python/sglang/srt/managers/expert_location.py | 1 + python/sglang/srt/managers/scheduler.py | 29 +++++---- .../sglang/srt/managers/tokenizer_manager.py | 5 +- .../srt/managers/tp_worker_overlap_thread.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 61 ++++++++++--------- 7 files changed, 109 insertions(+), 58 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 43343864c11..b484e9564ae 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -30,6 +30,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -521,7 +522,15 @@ def _launch_subprocesses( ) proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args, expert_location_metadata, gpu_id, tp_rank, None, writer), + args=( + server_args, + port_args, + expert_location_metadata, + gpu_id, + tp_rank, + None, + writer, + ), ) with memory_saver_adapter.configure_subprocess(): proc.start() @@ -570,7 +579,9 @@ def _launch_subprocesses( detoken_proc.start() # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args, expert_location_metadata) + tokenizer_manager = TokenizerManager( + server_args, port_args, expert_location_metadata + ) if server_args.chat_template: load_chat_template_for_openai_api( tokenizer_manager, server_args.chat_template, server_args.model_path diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 0223a8067eb..36b16b8bc43 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -22,6 +22,7 @@ import psutil import setproctitle import zmq + from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( @@ -54,8 +55,12 @@ def from_str(cls, method: str): class DataParallelController: """A controller that dispatches requests to multiple data parallel workers.""" - def __init__(self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata) -> None: + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + expert_location_metadata: ExpertLocationMetadata, + ) -> None: # Parse args self.max_total_num_tokens = None self.server_args = server_args @@ -84,10 +89,14 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs, self.workers = [None] * server_args.dp_size if server_args.enable_dp_attention: - dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args, expert_location_metadata) + dp_port_args = self.launch_dp_attention_schedulers( + server_args, port_args, expert_location_metadata + ) self.control_message_step = server_args.tp_size else: - dp_port_args = self.launch_dp_schedulers(server_args, port_args, expert_location_metadata) + dp_port_args = self.launch_dp_schedulers( + server_args, port_args, expert_location_metadata + ) self.control_message_step = 1 # Only node 
rank 0 runs the real data parallel controller that dispatches the requests. @@ -125,7 +134,14 @@ def launch_dp_schedulers(self, server_args, port_args, expert_location_metadata) # Create a thread for each worker thread = threading.Thread( target=self.launch_tensor_parallel_group_thread, - args=(server_args, tmp_port_args, expert_location_metadata, base_gpu_id, dp_rank, ready_event), + args=( + server_args, + tmp_port_args, + expert_location_metadata, + base_gpu_id, + dp_rank, + ready_event, + ), ) threads.append(thread) base_gpu_id += server_args.tp_size * server_args.gpu_id_step @@ -151,7 +167,9 @@ def launch_tensor_parallel_group_thread( dp_rank: int, ready_event: threading.Event, ): - self.launch_tensor_parallel_group(server_args, port_args, expert_location_metadata, base_gpu_id, dp_rank) + self.launch_tensor_parallel_group( + server_args, port_args, expert_location_metadata, base_gpu_id, dp_rank + ) ready_event.set() # This thread cannot be closed because otherwise the `kill_itself_when_parent_died` @@ -159,8 +177,12 @@ def launch_tensor_parallel_group_thread( while True: pass - def launch_dp_attention_schedulers(self, server_args, port_args, expert_location_metadata): - self.launch_tensor_parallel_group(server_args, port_args, expert_location_metadata, 0, None) + def launch_dp_attention_schedulers( + self, server_args, port_args, expert_location_metadata + ): + self.launch_tensor_parallel_group( + server_args, port_args, expert_location_metadata, 0, None + ) dp_port_args = [] for dp_rank in range(server_args.dp_size): dp_port_args.append(PortArgs.init_new(server_args, dp_rank)) @@ -209,7 +231,15 @@ def launch_tensor_parallel_group( ) proc = mp.Process( target=run_scheduler_process, - args=(server_args, rank_port_args, expert_location_metadata, gpu_id, tp_rank, dp_rank, writer), + args=( + server_args, + rank_port_args, + expert_location_metadata, + gpu_id, + tp_rank, + dp_rank, + writer, + ), ) proc.start() self.scheduler_procs.append(proc) @@ -263,7 +293,9 @@ def run_data_parallel_controller_process( parent_process = psutil.Process().parent() try: - controller = DataParallelController(server_args, port_args, expert_location_metadata) + controller = DataParallelController( + server_args, port_args, expert_location_metadata + ) pipe_writer.send( { "status": "ready", diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c7e3a61288d..52e8b53e026 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -60,5 +60,6 @@ def clone(self): physical_to_logical_map=self.physical_to_logical_map.clone(), ) + def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e4ae286200f..67897939f84 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,6 +32,8 @@ import setproctitle import torch import zmq +from torch.distributed import barrier + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -130,7 +132,6 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -371,8 +372,8 @@ def __init__( 
1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1251,10 +1252,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1279,9 +1280,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1513,8 +1514,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) @@ -2039,7 +2040,9 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: - scheduler = Scheduler(server_args, port_args, expert_location_metadata, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler( + server_args, port_args, expert_location_metadata, gpu_id, tp_rank, dp_rank + ) pipe_writer.send( { "status": "ready", diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5274315a216..93a9507bf19 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -961,8 +962,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3579c986551..ab1e3c53eee 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -62,7 +62,9 @@ def __init__( nccl_port: int, ): # Load the model - self.worker = TpModelWorker(server_args,expert_location_metadata, gpu_id, tp_rank, dp_rank, nccl_port) + self.worker = TpModelWorker( + server_args, expert_location_metadata, gpu_id, tp_rank, dp_rank, nccl_port + ) self.max_running_requests = self.worker.max_running_requests self.device = self.worker.device self.gpu_id = gpu_id diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0dafbe5ed08..2380031fc73 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,6 +22,10 @@ import torch import 
torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -70,9 +74,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -410,7 +411,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -523,12 +524,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -540,8 +541,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -582,7 +583,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -769,16 +770,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -831,11 +832,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., 
self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -911,15 +912,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -976,7 +977,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1494,7 +1495,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1525,11 +1526,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From d542934186fd785b57169d7029016bd8f07632c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:44:43 +0800 Subject: [PATCH 0265/1089] more --- python/sglang/srt/managers/expert_location.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 39 ++++++++++--------- .../sglang/srt/model_executor/model_runner.py | 2 +- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8e3a3c4f18e..bf5fae6b6a9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata._init_dummy() + return ExpertLocationMetadata.init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def _init_dummy(): + def init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cce17729e50..3a5f51f571b 100644 --- a/python/sglang/srt/managers/schedule_batch.py 
+++ b/python/sglang/srt/managers/schedule_batch.py @@ -3,6 +3,8 @@ import hashlib from enum import Enum, auto +from sglang.srt.managers.expert_location import ExpertLocationMetadata + # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,6 +86,7 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } +global_expert_location_metadata = ExpertLocationMetadata.init_dummy() logger = logging.getLogger(__name__) @@ -605,7 +608,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset :], self.read_offset - self.surr_offset + return all_ids[self.surr_offset:], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -646,7 +649,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] ) for stop_str in self.sampling_params.stop_strs: @@ -942,15 +945,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt : pt + req.extend_input_len] + self.out_cache_loc[pt: pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -989,7 +992,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1064,8 +1067,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1 : global_end_idx + 1 - ] + global_start_idx + 1: global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1139,7 +1142,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], + out_cache_loc[pt: pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1256,18 +1259,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // 
server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1291,8 +1294,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1389,7 +1392,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 36b691eb4ed..92410ffbcca 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -196,7 +196,7 @@ def initialize(self, min_per_gpu_memory: float): expert_distribution_recorder.initialize( server_args, - self.expert_location_metadata, + global_expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) From ad5a7d50a85af9b4395c5eed035a2153e589d77e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:45:30 +0800 Subject: [PATCH 0266/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index bf5fae6b6a9..8e3a3c4f18e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata.init_dummy() + return ExpertLocationMetadata._init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def init_dummy(): + def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a5f51f571b..eabdfae6822 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -86,7 +86,7 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata = ExpertLocationMetadata.init_dummy() +global_expert_location_metadata: Optional[ExpertLocationMetadata] = None logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py 
b/python/sglang/srt/model_executor/model_runner.py index 92410ffbcca..7b83f680f24 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,7 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -173,14 +173,15 @@ def __init__( } ) + global global_expert_location_metadata + global_expert_location_metadata = expert_location_metadata + # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() - self.expert_location_metadata = expert_location_metadata.clone() - # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) From fb9b8b1ac22844cf53c435ee1b2582012abb71ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:46:46 +0800 Subject: [PATCH 0267/1089] Revert "more" This reverts commit ad5a7d50a85af9b4395c5eed035a2153e589d77e. --- python/sglang/srt/managers/expert_location.py | 4 ++-- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8e3a3c4f18e..bf5fae6b6a9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata._init_dummy() + return ExpertLocationMetadata.init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def _init_dummy(): + def init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index eabdfae6822..3a5f51f571b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -86,7 +86,7 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata: Optional[ExpertLocationMetadata] = None +global_expert_location_metadata = ExpertLocationMetadata.init_dummy() logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7b83f680f24..92410ffbcca 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,7 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location 
import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -173,15 +173,14 @@ def __init__( } ) - global global_expert_location_metadata - global_expert_location_metadata = expert_location_metadata - # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.expert_location_metadata = expert_location_metadata.clone() + # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) From f0ec4fd6a9b220da1cc93054ff36669c80c26db3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:47:44 +0800 Subject: [PATCH 0268/1089] fmt --- python/sglang/srt/managers/schedule_batch.py | 36 ++++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a5f51f571b..c78f4ee29d0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -608,7 +608,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset:], self.read_offset - self.surr_offset + return all_ids[self.surr_offset :], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -649,7 +649,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] ) for stop_str in self.sampling_params.stop_strs: @@ -945,15 +945,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt: pt + req.extend_input_len] + self.out_cache_loc[pt : pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -992,7 +992,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1067,8 +1067,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1: global_end_idx + 1 - ] + global_start_idx + 1 : global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1142,7 +1142,7 @@ def prepare_for_extend(self): for i in range(bs): 
self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt: pt + extend_lens[i]], + out_cache_loc[pt : pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1259,18 +1259,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1294,8 +1294,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1392,7 +1392,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: From e33bbf2e77bd3c265498e02553bde66ed0bd6f6a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:51:30 +0800 Subject: [PATCH 0269/1089] more --- python/sglang/srt/managers/schedule_batch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cce17729e50..f3a27f0a446 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -3,6 +3,8 @@ import hashlib from enum import Enum, auto +from sglang.srt.managers.expert_location import ExpertLocationMetadata + # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
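(From this commit on, scheduler-side code reaches the metadata through the module-level global_expert_location_metadata, initialized empty here and filled in via update() by the model runner in the follow-up commits. The index helpers the metadata exposes in expert_location.py are plain rank-offset arithmetic; a standalone sketch with an assumed per-rank expert count:

    NUM_LOCAL_PHYSICAL_EXPERTS = 8  # illustrative; the real value lives in the metadata

    def local_physical_to_global_physical(rank: int, local_idx: int) -> int:
        # Rank r owns the contiguous block [r * n, (r + 1) * n) of physical experts.
        return NUM_LOCAL_PHYSICAL_EXPERTS * rank + local_idx

    def global_physical_to_local_physical(global_idx: int) -> int:
        return global_idx % NUM_LOCAL_PHYSICAL_EXPERTS

    assert local_physical_to_global_physical(rank=2, local_idx=3) == 19
    assert global_physical_to_local_physical(19) == 3

This is exactly the mapping used when duplicating or sharding experts across GPUs: global physical indices are contiguous per rank, so converting back never needs to know which rank is asking.)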
@@ -85,6 +87,8 @@ "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } +global_expert_location_metadata = ExpertLocationMetadata.init_empty() + logger = logging.getLogger(__name__) From 8ee30d36663fd09bb744261367e82db585e07247 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:51:40 +0800 Subject: [PATCH 0270/1089] more --- python/sglang/srt/managers/expert_location.py | 4 +-- python/sglang/srt/managers/schedule_batch.py | 36 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 52e8b53e026..aecc1611370 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata._init_dummy() + return ExpertLocationMetadata.init_empty() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def _init_dummy(): + def init_empty(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f3a27f0a446..a705102a9b8 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -609,7 +609,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset :], self.read_offset - self.surr_offset + return all_ids[self.surr_offset:], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -650,7 +650,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] ) for stop_str in self.sampling_params.stop_strs: @@ -946,15 +946,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt : pt + req.extend_input_len] + self.out_cache_loc[pt: pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -993,7 +993,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1068,8 +1068,8 @@ def prepare_for_extend(self): 
global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1 : global_end_idx + 1 - ] + global_start_idx + 1: global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1143,7 +1143,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], + out_cache_loc[pt: pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1260,18 +1260,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1295,8 +1295,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1393,7 +1393,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: From af9291f67544aedd2ae0932efed4cb2b5ae1ad00 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:51:59 +0800 Subject: [PATCH 0271/1089] more --- python/sglang/srt/model_executor/model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 36b691eb4ed..a1504735913 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,7 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -172,6 +172,7 @@ def __init__( "use_mla_backend": self.use_mla_backend, } ) + global_expert_location_metadata.update(expert_location_metadata) # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) @@ -179,8 +180,6 @@ def __init__( # Get memory before model loading min_per_gpu_memory = 
self.init_torch_distributed() - self.expert_location_metadata = expert_location_metadata.clone() - # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) @@ -196,7 +195,7 @@ def initialize(self, min_per_gpu_memory: float): expert_distribution_recorder.initialize( server_args, - self.expert_location_metadata, + global_expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) From 4e6803c9b5b7d6208bf4f96ccd55f90f47ddc3ab Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:52:40 +0800 Subject: [PATCH 0272/1089] more --- python/sglang/srt/managers/expert_location.py | 12 +++++------- python/sglang/srt/model_executor/model_runner.py | 5 ++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index aecc1611370..4a66909b3b7 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -52,13 +52,11 @@ def local_physical_to_global_physical( def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts - def clone(self): - return ExpertLocationMetadata( - num_layers=self.num_layers, - num_local_physical_experts=self.num_local_physical_experts, - num_logical_experts=self.num_logical_experts, - physical_to_logical_map=self.physical_to_logical_map.clone(), - ) + def update(self, other: "ExpertLocationMetadata"): + if self_is_empty: + TODO + else: + pass # will handle later def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a1504735913..a473d8b9905 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -175,7 +174,7 @@ def __init__( global_expert_location_metadata.update(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -922,7 +921,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From c8db6f86e6d3211d8dfa1bb5c5f99890b14409e1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:53:17 +0800 Subject: [PATCH 0273/1089] more --- python/sglang/srt/managers/expert_location.py | 17 +++++++++-------- python/sglang/srt/managers/schedule_batch.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 4a66909b3b7..39fdee18a05 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -9,6 +8,7 @@ @dataclass class ExpertLocationMetadata: + is_dummy: bool num_layers: int num_local_physical_experts: int num_logical_experts: int @@ -20,17 +20,18 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata.init_empty() + return ExpertLocationMetadata.init_dummy() @staticmethod - def init_new(num_layers: int, num_logical_experts: int): + def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts return ExpertLocationMetadata( + is_dummy=is_dummy, num_layers=num_layers, num_logical_experts=num_logical_experts, num_local_physical_experts=num_local_physical_experts, @@ -41,11 +42,11 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def init_empty(): - return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) + def init_dummy(): + return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1, is_dummy=True) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index @@ -56,7 +57,7 @@ def update(self, other: "ExpertLocationMetadata"): if self_is_empty: TODO else: - pass # will handle later + pass # will handle later def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a705102a9b8..dae80d54f2c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -87,7 +87,7 @@ "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata = ExpertLocationMetadata.init_empty() +global_expert_location_metadata = ExpertLocationMetadata.init_dummy() logger = logging.getLogger(__name__) From b44d974c6c91d471c3427115bdbbd78a056c8281 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:54:02 
+0800 Subject: [PATCH 0274/1089] more --- python/sglang/srt/managers/expert_location.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 39fdee18a05..9f68d9813e2 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -54,10 +54,13 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def update(self, other: "ExpertLocationMetadata"): - if self_is_empty: - TODO + if self.is_dummy: + self.num_layers = other.num_layers + self.num_local_physical_experts = other.num_local_physical_experts + self.num_logical_experts = other.num_logical_experts + self.physical_to_logical_map = other.physical_to_logical_map.detach().clone() else: - pass # will handle later + raise NotImplementedError # will handle later def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From 8dc939f1bcc5fa2bbe2fa107423682c57396044b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:54:36 +0800 Subject: [PATCH 0275/1089] more --- python/sglang/srt/managers/expert_location.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9f68d9813e2..cea158595a5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -55,13 +55,17 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: - self.num_layers = other.num_layers - self.num_local_physical_experts = other.num_local_physical_experts - self.num_logical_experts = other.num_logical_experts - self.physical_to_logical_map = other.physical_to_logical_map.detach().clone() + self._update_by_assign(other) else: raise NotImplementedError # will handle later + def _update_by_assign(self, other: "ExpertLocationMetadata"): + self.is_dummy = other.is_dummy + self.num_layers = other.num_layers + self.num_local_physical_experts = other.num_local_physical_experts + self.num_logical_experts = other.num_logical_experts + self.physical_to_logical_map = other.physical_to_logical_map.detach().clone() + def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) From a12276c1657980d2414f9637f3fc29d1be14c216 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:55:23 +0800 Subject: [PATCH 0276/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index cea158595a5..0b1764f597c 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -55,11 +55,11 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: - self._update_by_assign(other) + self._update_unconditionally(other) else: raise NotImplementedError # will handle later - def _update_by_assign(self, other: "ExpertLocationMetadata"): + def _update_unconditionally(self, other: "ExpertLocationMetadata"): self.is_dummy = other.is_dummy 
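The assignment path above copies the tensor field with `.detach().clone()`: detach drops any autograd history, and clone gives the global its own storage so it no longer aliases the source tensor. A short stand-alone check of both properties (assumes only torch):

    import torch

    src = torch.arange(4.0, requires_grad=True)
    kept = src.detach().clone()   # no grad history, independent storage

    with torch.no_grad():
        src.mul_(10)              # later in-place changes to the source...

    print(kept)                   # ...do not reach the copy: tensor([0., 1., 2., 3.])
    print(kept.requires_grad)     # False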
self.num_layers = other.num_layers self.num_local_physical_experts = other.num_local_physical_experts From 39222518835c83a5f24d9cfee0496fb035fa072b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:56:58 +0800 Subject: [PATCH 0277/1089] more --- python/sglang/srt/managers/expert_location.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 0b1764f597c..3f5830e78b5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -57,15 +57,29 @@ def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: self._update_unconditionally(other) else: - raise NotImplementedError # will handle later + self._update_partial(other) def _update_unconditionally(self, other: "ExpertLocationMetadata"): - self.is_dummy = other.is_dummy - self.num_layers = other.num_layers - self.num_local_physical_experts = other.num_local_physical_experts - self.num_logical_experts = other.num_logical_experts + for field in _TRIVIAL_UPDATE_FIELDS: + setattr(self, field, getattr(other, field)) + self.physical_to_logical_map = other.physical_to_logical_map.detach().clone() + def _update_partial(self, other: "ExpertLocationMetadata"): + for field in _TRIVIAL_UPDATE_FIELDS: + assert getattr(self, field) == getattr(other, field) + + # Cannot update address to avoid breaking CUDA graph + self.physical_to_logical_map[...] = other.physical_to_logical_map + + +_TRIVIAL_UPDATE_FIELDS = [ + "is_dummy", + "num_layers", + "num_local_physical_experts", + "num_logical_experts", +] + def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): return torch.arange(0, num_physical_experts).repeat(num_layers, 1) From 4a82f7a2b2a1ecf317a5618a13aa54e7f57affa8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 15:57:54 +0800 Subject: [PATCH 0278/1089] more --- python/sglang/srt/managers/expert_location.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3f5830e78b5..29159bf85a5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -60,25 +60,28 @@ def update(self, other: "ExpertLocationMetadata"): self._update_partial(other) def _update_unconditionally(self, other: "ExpertLocationMetadata"): - for field in _TRIVIAL_UPDATE_FIELDS: + for field in _UPDATE_FIELDS_TRIVIAL: setattr(self, field, getattr(other, field)) - - self.physical_to_logical_map = other.physical_to_logical_map.detach().clone() + for field in _UPDATE_FIELDS_TENSOR: + setattr(self, field, getattr(other, field).detach().clone()) def _update_partial(self, other: "ExpertLocationMetadata"): - for field in _TRIVIAL_UPDATE_FIELDS: + for field in _UPDATE_FIELDS_TRIVIAL: assert getattr(self, field) == getattr(other, field) - - # Cannot update address to avoid breaking CUDA graph - self.physical_to_logical_map[...] = other.physical_to_logical_map + for field in _UPDATE_FIELDS_TENSOR: + # Cannot update address to avoid breaking CUDA graph + getattr(self, field)[...] 
= getattr(other, field) -_TRIVIAL_UPDATE_FIELDS = [ +_UPDATE_FIELDS_TRIVIAL = [ "is_dummy", "num_layers", "num_local_physical_experts", "num_logical_experts", ] +_UPDATE_FIELDS_TENSOR = [ + "physical_to_logical_map", +] def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int): From fba78f3d180ecc29fec5f5c45b3e847fee0f6aac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:00:31 +0800 Subject: [PATCH 0279/1089] Revert "fmt" This reverts commit f0ec4fd6a9b220da1cc93054ff36669c80c26db3. --- python/sglang/srt/managers/schedule_batch.py | 36 ++++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c78f4ee29d0..3a5f51f571b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -608,7 +608,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset :], self.read_offset - self.surr_offset + return all_ids[self.surr_offset:], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -649,7 +649,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] ) for stop_str in self.sampling_params.stop_strs: @@ -945,15 +945,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt : pt + req.extend_input_len] + self.out_cache_loc[pt: pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -992,7 +992,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1067,8 +1067,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1 : global_end_idx + 1 - ] + global_start_idx + 1: global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1142,7 +1142,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], + out_cache_loc[pt: pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1259,18 +1259,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - 
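The `getattr(self, field)[...] = getattr(other, field)` spelling overwrites the tensor's contents while keeping its storage address, which is what a captured CUDA graph (or any cached device pointer) requires; rebinding the attribute would allocate a fresh buffer the graph never sees. A CPU-only sketch of the address-stability point:

    import torch

    buf = torch.zeros(2, 4, dtype=torch.int64)
    addr_before = buf.data_ptr()

    new_values = torch.arange(8, dtype=torch.int64).reshape(2, 4)

    buf[...] = new_values                 # in-place copy: same storage, new contents
    assert buf.data_ptr() == addr_before

    buf = new_values.clone()              # rebinding: a different allocation
    assert buf.data_ptr() != addr_before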
req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1294,8 +1294,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1392,7 +1392,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: From b46ebce82c00e50ec3d6073a42b6019518871cc9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:00:31 +0800 Subject: [PATCH 0280/1089] Revert "Revert "more"" This reverts commit fb9b8b1ac22844cf53c435ee1b2582012abb71ca. --- python/sglang/srt/managers/expert_location.py | 4 ++-- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index bf5fae6b6a9..8e3a3c4f18e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata.init_dummy() + return ExpertLocationMetadata._init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def init_dummy(): + def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a5f51f571b..eabdfae6822 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -86,7 +86,7 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata = ExpertLocationMetadata.init_dummy() +global_expert_location_metadata: Optional[ExpertLocationMetadata] = None logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 92410ffbcca..7b83f680f24 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,7 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -173,14 +173,15 @@ def __init__( } ) + global global_expert_location_metadata + global_expert_location_metadata = expert_location_metadata + # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() - self.expert_location_metadata = expert_location_metadata.clone() - # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) From be9c6b33ffc354c9c30749c56e273189f6c78627 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:00:31 +0800 Subject: [PATCH 0281/1089] Revert "more" This reverts commit ad5a7d50a85af9b4395c5eed035a2153e589d77e. --- python/sglang/srt/managers/expert_location.py | 4 ++-- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8e3a3c4f18e..bf5fae6b6a9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata._init_dummy() + return ExpertLocationMetadata.init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def _init_dummy(): + def init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index eabdfae6822..3a5f51f571b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -86,7 +86,7 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata: Optional[ExpertLocationMetadata] = None +global_expert_location_metadata = ExpertLocationMetadata.init_dummy() logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7b83f680f24..92410ffbcca 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,7 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import 
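The `global global_expert_location_metadata` attempt above only targets the declaring module's namespace: rebinding a name that arrived via `from schedule_batch import ...` changes model_runner's binding while leaving schedule_batch's own (and every other importer's) untouched, which is why the approach keeps getting reverted. A self-contained demonstration using throwaway modules:

    import sys
    import types

    owner = types.ModuleType("owner")
    exec("value = None", owner.__dict__)
    sys.modules["owner"] = owner

    user = types.ModuleType("user")
    exec(
        "from owner import value\n"
        "def set_value(v):\n"
        "    global value\n"
        "    value = v\n",
        user.__dict__,
    )

    user.set_value(42)
    print(user.value)   # 42   -- only the importing module's binding changed
    print(owner.value)  # None -- the owner and all other importers are stale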
global_server_args_dict, global_expert_location_metadata +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -173,15 +173,14 @@ def __init__( } ) - global global_expert_location_metadata - global_expert_location_metadata = expert_location_metadata - # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.expert_location_metadata = expert_location_metadata.clone() + # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) From 79a6b78eb5ab339bc79c5b868a16c960fed647a8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:00:31 +0800 Subject: [PATCH 0282/1089] Revert "more" This reverts commit d542934186fd785b57169d7029016bd8f07632c1. --- python/sglang/srt/managers/expert_location.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 39 +++++++++---------- .../sglang/srt/model_executor/model_runner.py | 2 +- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index bf5fae6b6a9..8e3a3c4f18e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata.init_dummy() + return ExpertLocationMetadata._init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -41,7 +41,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def init_dummy(): + def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a5f51f571b..cce17729e50 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -3,8 +3,6 @@ import hashlib from enum import Enum, auto -from sglang.srt.managers.expert_location import ExpertLocationMetadata - # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -86,7 +84,6 @@ "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } -global_expert_location_metadata = ExpertLocationMetadata.init_dummy() logger = logging.getLogger(__name__) @@ -608,7 +605,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset:], self.read_offset - self.surr_offset + return all_ids[self.surr_offset :], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -649,7 +646,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] ) for stop_str in self.sampling_params.stop_strs: @@ -945,15 +942,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt: pt + req.extend_input_len] + self.out_cache_loc[pt : pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -992,7 +989,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1067,8 +1064,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1: global_end_idx + 1 - ] + global_start_idx + 1 : global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1142,7 +1139,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt: pt + extend_lens[i]], + out_cache_loc[pt : pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1259,18 +1256,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) 
self.req_to_token_pool.free(req.req_pool_idx) @@ -1294,8 +1291,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1392,7 +1389,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 92410ffbcca..36b691eb4ed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -196,7 +196,7 @@ def initialize(self, min_per_gpu_memory: float): expert_distribution_recorder.initialize( server_args, - global_expert_location_metadata, + self.expert_location_metadata, # TODO handle DP!=TP case rank=self.tp_rank, ) From 609b9f22d39bf7fd237c2f7acc5a28431282e5b0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:01:02 +0800 Subject: [PATCH 0283/1089] fmt --- python/sglang/srt/managers/expert_location.py | 9 +++-- python/sglang/srt/managers/schedule_batch.py | 36 +++++++++---------- .../sglang/srt/model_executor/model_runner.py | 10 ++++-- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 29159bf85a5..7f549bd2114 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -26,7 +27,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -43,10 +44,12 @@ def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): @staticmethod def init_dummy(): - return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1, is_dummy=True) + return ExpertLocationMetadata.init_new( + num_layers=1, num_logical_experts=1, is_dummy=True + ) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index dae80d54f2c..0fb90ae778e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -609,7 +609,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset:], self.read_offset 
- self.surr_offset + return all_ids[self.surr_offset :], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -650,7 +650,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] ) for stop_str in self.sampling_params.stop_strs: @@ -946,15 +946,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt: pt + req.extend_input_len] + self.out_cache_loc[pt : pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -993,7 +993,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1068,8 +1068,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1: global_end_idx + 1 - ] + global_start_idx + 1 : global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1143,7 +1143,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt: pt + extend_lens[i]], + out_cache_loc[pt : pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1260,18 +1260,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1295,8 +1295,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, 
new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1393,7 +1393,7 @@ def filter_batch( i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a473d8b9905..fe684980f40 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -46,7 +47,10 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata +from sglang.srt.managers.schedule_batch import ( + global_expert_location_metadata, + global_server_args_dict, +) from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, @@ -174,7 +178,7 @@ def __init__( global_expert_location_metadata.update(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -921,7 +925,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
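The repeated "fmt" churn in these hunks is black versus hand-written slice spacing: per PEP 8, black treats the slice colon like a binary operator and adds surrounding spaces only when a bound is a non-trivial expression. Roughly:

    items = list(range(10))
    offset = 3

    items[offset:]       # simple name: no spaces, both styles agree
    items[offset + 1 :]  # complex bound: black inserts the space (PEP 8 slice rule)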
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From 1cbfdf70096f022d2fba998fcf42f61c4c3e1d67 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:02:02 +0800 Subject: [PATCH 0284/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index cffc15a4269..ccde8c87f92 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,6 +3,8 @@ import torch +from sglang.srt.managers.schedule_batch import global_expert_location_metadata + try: from deep_gemm import ( get_col_major_tma_aligned_tensor, @@ -410,7 +412,7 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - physical_expert_ids = expert_location_metadata.logical_to_global_physical( + physical_expert_ids = global_expert_location_metadata.logical_to_global_physical( expert_id ) for physical_expert_id in physical_expert_ids: From 56b2b8047126f00f64887bf4ad1003e5d3acf9ac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:02:48 +0800 Subject: [PATCH 0285/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ++-- python/sglang/srt/managers/expert_location.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ccde8c87f92..e61e2e02cdd 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -412,8 +412,8 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - physical_expert_ids = global_expert_location_metadata.logical_to_global_physical( - expert_id + physical_expert_ids = ( + global_expert_location_metadata.logical_to_global_physical(expert_id) ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index ea1d3e4f9e3..b4b614a5aaa 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -57,7 +57,7 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_global_physical(self, logical_expert_id: int): - return logical_expert_id # TODO add a logical_to_physical_map + return [logical_expert_id] # TODO add a logical_to_physical_map def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: From 77fc845e5701fc59f37fd4c14bab302baf11e42b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:04:02 +0800 Subject: [PATCH 0286/1089] more --- python/sglang/srt/layers/moe/topk.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index a0bd81d43ec..f979e18a983 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -12,12 +12,10 @@ # limitations under the License. 
# ============================================================================== -import os from typing import Callable, Optional import torch import torch.nn.functional as F - from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip @@ -298,6 +296,9 @@ def select_experts( renormalize=renormalize, ) + # TODO this is inefficient, and I will fuse into existing kernels + TODO + expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids From f7c96de7416d78ccb05a75e123bc1e08d9d62f96 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:04:21 +0800 Subject: [PATCH 0287/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index f979e18a983..3633f6997da 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -297,7 +297,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - TODO + topk_ids = partial_logical_to_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From 66f641a05300392c17fdb2ada6ed49acfec97970 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:04:45 +0800 Subject: [PATCH 0288/1089] more --- python/sglang/srt/layers/moe/topk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 3633f6997da..38b00ee8cc5 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() @@ -297,7 +297,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = partial_logical_to_physical_map[topk_ids] + topk_ids = global_expert_location_metadata.partial_logical_to_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From 4c4ee66b4916dd98cb7af1a5eb870fec40191bc0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:05:39 +0800 Subject: [PATCH 0289/1089] more --- python/sglang/srt/managers/expert_location.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index b4b614a5aaa..06589e704d9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -15,6 +14,7 @@ class ExpertLocationMetadata: num_logical_experts: int # will have a `logical_to_physical_map` later physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) + partial_logical_to_physical_map: torch.Tensor # (layers, 
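Indexing a logical-to-physical map with the `topk_ids` tensor is an elementwise gather: every logical expert id is replaced by its current physical id, with the tensor shape preserved. A toy single-layer version of the remap, with a hypothetical map:

    import torch

    # Hypothetical map: logical expert i currently lives at physical slot map[i].
    logical_to_physical = torch.tensor([3, 0, 2, 1])

    topk_ids = torch.tensor([[0, 2],
                             [1, 3]])          # (num_tokens, top_k) logical ids
    topk_ids = logical_to_physical[topk_ids]   # gather; shape is unchanged
    print(topk_ids)                            # tensor([[3, 2], [0, 1]])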
num_logical_experts) @staticmethod def from_model_config(model_config: ModelConfig): @@ -27,11 +27,11 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts - return ExpertLocationMetadata( + output = ExpertLocationMetadata( is_dummy=is_dummy, num_layers=num_layers, num_logical_experts=num_logical_experts, @@ -40,7 +40,10 @@ def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): num_layers=num_layers, num_physical_experts=num_physical_experts, ), + partial_logical_to_physical_map=None, ) + output._rebuild() + return output @staticmethod def init_dummy(): @@ -49,7 +52,7 @@ def init_dummy(): ) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index @@ -59,6 +62,9 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): def logical_to_global_physical(self, logical_expert_id: int): return [logical_expert_id] # TODO add a logical_to_physical_map + def _rebuild(self): + self.partial_logical_to_physical_map = TODO + def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: self._update_unconditionally(other) From 59d74fc875d19fe393f0f376106aa04eb3339834 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:06:11 +0800 Subject: [PATCH 0290/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 06589e704d9..892aee9b122 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -40,7 +40,7 @@ def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): num_layers=num_layers, num_physical_experts=num_physical_experts, ), - partial_logical_to_physical_map=None, + partial_logical_to_physical_map=TODO, ) output._rebuild() return output @@ -63,13 +63,14 @@ def logical_to_global_physical(self, logical_expert_id: int): return [logical_expert_id] # TODO add a logical_to_physical_map def _rebuild(self): - self.partial_logical_to_physical_map = TODO + self.partial_logical_to_physical_map[...] = TODO def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: self._update_unconditionally(other) else: self._update_partial(other) + self._rebuild() def _update_unconditionally(self, other: "ExpertLocationMetadata"): for field in _UPDATE_FIELDS_TRIVIAL: @@ -82,7 +83,7 @@ def _update_partial(self, other: "ExpertLocationMetadata"): assert getattr(self, field) == getattr(other, field) for field in _UPDATE_FIELDS_TENSOR: # Cannot update address to avoid breaking CUDA graph - getattr(self, field)[...] 
= getattr(other, field) + getattr(self, field) = getattr(other, field) _UPDATE_FIELDS_TRIVIAL = [ From b642d61fa455e2bb1809fb3ed373ff6d54d61fe7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:06:44 +0800 Subject: [PATCH 0291/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/srt/managers/expert_location.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 38b00ee8cc5..10bba0d29c4 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -297,7 +297,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = global_expert_location_metadata.partial_logical_to_physical_map[topk_ids] + topk_ids = global_expert_location_metadata.chosen_logical_to_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 892aee9b122..1d572ed8ad6 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -14,7 +14,7 @@ class ExpertLocationMetadata: num_logical_experts: int # will have a `logical_to_physical_map` later physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) - partial_logical_to_physical_map: torch.Tensor # (layers, num_logical_experts) + chosen_logical_to_physical_map: torch.Tensor # (layers, num_logical_experts) @staticmethod def from_model_config(model_config: ModelConfig): @@ -40,7 +40,7 @@ def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): num_layers=num_layers, num_physical_experts=num_physical_experts, ), - partial_logical_to_physical_map=TODO, + chosen_logical_to_physical_map=TODO, ) output._rebuild() return output @@ -63,7 +63,7 @@ def logical_to_global_physical(self, logical_expert_id: int): return [logical_expert_id] # TODO add a logical_to_physical_map def _rebuild(self): - self.partial_logical_to_physical_map[...] = TODO + self.chosen_logical_to_physical_map[...] = TODO def update(self, other: "ExpertLocationMetadata"): if self.is_dummy: From 8952dae9e365a83a972453c557a6649e5701d506 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:07:35 +0800 Subject: [PATCH 0292/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 1d572ed8ad6..533c867d0da 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -83,7 +83,7 @@ def _update_partial(self, other: "ExpertLocationMetadata"): assert getattr(self, field) == getattr(other, field) for field in _UPDATE_FIELDS_TENSOR: # Cannot update address to avoid breaking CUDA graph - getattr(self, field) = getattr(other, field) + getattr(self, field)[...] 
= getattr(other, field) _UPDATE_FIELDS_TRIVIAL = [ From 8e8dd6c67e41df89d1cfa5716d9fdf8a5ce5e601 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:07:49 +0800 Subject: [PATCH 0293/1089] fmt --- python/sglang/srt/layers/moe/topk.py | 6 +++++- python/sglang/srt/managers/expert_location.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 10bba0d29c4..f3d42512ec4 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,8 +16,12 @@ import torch import torch.nn.functional as F + from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.schedule_batch import global_server_args_dict, global_expert_location_metadata +from sglang.srt.managers.schedule_batch import ( + global_expert_location_metadata, + global_server_args_dict, +) from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 533c867d0da..bddfcaa3bdd 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -27,7 +28,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -52,7 +53,7 @@ def init_dummy(): ) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index From 3dfd35923bba6f77abcf97836d967edeb00a14bb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:09:47 +0800 Subject: [PATCH 0294/1089] more --- python/sglang/srt/managers/expert_location.py | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7f549bd2114..7ef8f3bcecf 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -9,7 +8,6 @@ @dataclass class ExpertLocationMetadata: - is_dummy: bool num_layers: int num_local_physical_experts: int num_logical_experts: int @@ -24,15 +22,14 @@ def from_model_config(model_config: ModelConfig): return ExpertLocationMetadata.init_dummy() @staticmethod - def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): + def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // 
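The two update spellings settled on here reflect a Python constraint: `getattr(obj, name) = value` is a syntax error, because a call is not a valid assignment target. Rebinding by name string must go through `setattr`, while in-place writes index the tensor object `getattr` returns. Compactly:

    import torch

    class Holder:
        def __init__(self):
            self.t = torch.zeros(3)

    h, src = Holder(), torch.ones(3)

    setattr(h, "t", src.clone())   # rebinding: the only legal spelling via a name string
    getattr(h, "t")[...] = src     # in-place: index the object getattr returned
    # getattr(h, "t") = src        # SyntaxError: cannot assign to function call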
get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts return ExpertLocationMetadata( - is_dummy=is_dummy, num_layers=num_layers, num_logical_experts=num_logical_experts, num_local_physical_experts=num_local_physical_experts, @@ -44,12 +41,10 @@ def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False): @staticmethod def init_dummy(): - return ExpertLocationMetadata.init_new( - num_layers=1, num_logical_experts=1, is_dummy=True - ) + return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( - self, rank: int, local_physical_expert_index: int + self, rank: int, local_physical_expert_index: int ): return self.num_local_physical_experts * rank + local_physical_expert_index @@ -57,18 +52,6 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def update(self, other: "ExpertLocationMetadata"): - if self.is_dummy: - self._update_unconditionally(other) - else: - self._update_partial(other) - - def _update_unconditionally(self, other: "ExpertLocationMetadata"): - for field in _UPDATE_FIELDS_TRIVIAL: - setattr(self, field, getattr(other, field)) - for field in _UPDATE_FIELDS_TENSOR: - setattr(self, field, getattr(other, field).detach().clone()) - - def _update_partial(self, other: "ExpertLocationMetadata"): for field in _UPDATE_FIELDS_TRIVIAL: assert getattr(self, field) == getattr(other, field) for field in _UPDATE_FIELDS_TENSOR: From 740b8e7dc1035d03f2b8e519ad6385734646a10e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:09:53 +0800 Subject: [PATCH 0295/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7ef8f3bcecf..66c784be393 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -19,7 +19,7 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_expert_location_metadata"): return model_class.get_expert_location_metadata(model_config.hf_config) - return ExpertLocationMetadata.init_dummy() + return ExpertLocationMetadata._init_dummy() @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -40,7 +40,7 @@ def init_new(num_layers: int, num_logical_experts: int): ) @staticmethod - def init_dummy(): + def _init_dummy(): return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) def local_physical_to_global_physical( From 5a6305e9e66666f9294475a9cd6b7fbbdb530026 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:10:07 +0800 Subject: [PATCH 0296/1089] more --- python/sglang/srt/managers/expert_location.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 66c784be393..d132874cccb 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -52,23 +52,20 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def update(self, other: "ExpertLocationMetadata"): - for field in 
From 5a6305e9e66666f9294475a9cd6b7fbbdb530026 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:10:07 +0800
Subject: [PATCH 0296/1089] more

---
 python/sglang/srt/managers/expert_location.py | 23 ++++++++-----------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 66c784be393..d132874cccb 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -52,23 +52,20 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
 
     def update(self, other: "ExpertLocationMetadata"):
-        for field in _UPDATE_FIELDS_TRIVIAL:
+        for field in [
+            "is_dummy",
+            "num_layers",
+            "num_local_physical_experts",
+            "num_logical_experts",
+        ]:
             assert getattr(self, field) == getattr(other, field)
-        for field in _UPDATE_FIELDS_TENSOR:
+
+        for field in [
+            "physical_to_logical_map",
+        ]:
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)
 
 
-_UPDATE_FIELDS_TRIVIAL = [
-    "is_dummy",
-    "num_layers",
-    "num_local_physical_experts",
-    "num_logical_experts",
-]
-_UPDATE_FIELDS_TENSOR = [
-    "physical_to_logical_map",
-]
-
-
 def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int):
     return torch.arange(0, num_physical_experts).repeat(num_layers, 1)

From 2b2da3161f8f81f0e11c415e25c870cb6004f081 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:10:15 +0800
Subject: [PATCH 0297/1089] more

---
 python/sglang/srt/managers/expert_location.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index d132874cccb..fbb3d2a5564 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -53,7 +53,6 @@ def global_physical_to_local_physical(self, global_physical_expert_index: int):
 
     def update(self, other: "ExpertLocationMetadata"):
         for field in [
-            "is_dummy",
             "num_layers",
             "num_local_physical_experts",
             "num_logical_experts",
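[Note: the `update()` method kept above asserts that scalar fields match and then copies tensor fields with `tensor[...] = other`, per the in-diff comment "Cannot update address to avoid breaking CUDA graph". A captured CUDA graph replays against fixed device pointers, so the buffer must be mutated in place rather than rebound. A small CPU-side illustration of the difference, not taken from the patches:]

    import torch

    captured = torch.zeros(2, 3)        # stands in for a buffer a CUDA graph captured
    new_values = torch.arange(6.).reshape(2, 3)

    address_before = captured.data_ptr()
    captured[...] = new_values          # in-place copy: contents change, storage does not move
    assert captured.data_ptr() == address_before

    # By contrast, `captured = new_values` would rebind the Python name to different
    # storage, and a replayed graph would keep reading the stale original buffer.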
From 1ea467b77049a8e2d8f3b35915b07dcf1476aa06 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:11:11 +0800
Subject: [PATCH 0298/1089] more

---
 python/sglang/srt/managers/schedule_batch.py | 49 ++++++++++++--------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 0fb90ae778e..8974cf976bd 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -87,7 +87,18 @@
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
 }
 
-global_expert_location_metadata = ExpertLocationMetadata.init_dummy()
+_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
+
+
+def get_global_expert_location_metadata():
+    return _global_expert_location_metadata
+
+
+def set_global_expert_location_metadata(value):
+    global _global_expert_location_metadata
+    assert _global_expert_location_metadata is None
+    _global_expert_location_metadata = value
+
 
 logger = logging.getLogger(__name__)
 
@@ -609,7 +620,7 @@ def init_incremental_detokenize(self):
         )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+        return all_ids[self.surr_offset:], self.read_offset - self.surr_offset
 
     def check_finished(self):
         if self.finished():
@@ -650,7 +661,7 @@ def check_finished(self):
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
+                self.output_ids[-(self.sampling_params.stop_str_max_len + 1):]
             )
 
             for stop_str in self.sampling_params.stop_strs:
@@ -946,15 +957,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int])
                 # NOTE: the encoder part should be considered as a whole
                 assert len(req.prefix_indices) == 0
                 input_ids[i] = input_ids[i][encoder_len:]
-                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len])
                 decoder_out_cache_loc.append(
-                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                    self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len]
                 )
                 self.extend_lens[i] -= encoder_len
                 self.extend_num_tokens -= encoder_len
             else:
                 decoder_out_cache_loc.append(
-                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                    self.out_cache_loc[pt: pt + req.extend_input_len]
                 )
                 self.prefix_lens[i] -= encoder_len
 
@@ -993,7 +1004,7 @@ def prepare_for_extend(self):
         # Init tensors
         reqs = self.reqs
-        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
@@ -1068,8 +1079,8 @@ def prepare_for_extend(self):
                     global_start_idx = req.logprob_start_len
 
                 logprob_token_ids = req.origin_input_ids[
-                    global_start_idx + 1 : global_end_idx + 1
-                ]
+                                    global_start_idx + 1: global_end_idx + 1
+                                    ]
                 extend_input_logprob_token_ids.extend(logprob_token_ids)
 
                 # We will need req.extend_input_len - req.extend_logprob_start_len number of
@@ -1143,7 +1154,7 @@ def prepare_for_extend(self):
         for i in range(bs):
             self.req_to_token_pool.write(
                 (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                out_cache_loc[pt : pt + extend_lens[i]],
+                out_cache_loc[pt: pt + extend_lens[i]],
             )
             pt += extend_lens[i]
 
@@ -1260,18 +1271,18 @@ def get_required_tokens(num_reqs: int):
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, : seq_lens_cpu[idx]
-                ]
+                                req.req_pool_idx, : seq_lens_cpu[idx]
+                                ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = (
-                    len(req.prefix_indices) // server_args.page_size
-                ) * server_args.page_size
+                                        len(req.prefix_indices) // server_args.page_size
+                                    ) * server_args.page_size
                 token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                ]
+                                req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx]
+                                ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
 
@@ -1295,8 +1306,8 @@ def get_required_tokens(num_reqs: int):
         total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
 
         new_estimate_ratio = (
-            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
-        ) / total_max_new_tokens
+                                     total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+                             ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
 
         return retracted_reqs, new_estimate_ratio
@@ -1393,7 +1404,7 @@ def filter_batch(
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-            and self.reqs[i] is not chunked_req_to_exclude
+               and self.reqs[i] is not chunked_req_to_exclude
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
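[Note: patch 0298 replaces the eagerly built `global_expert_location_metadata` with a module-level set-once singleton plus accessor functions. The shape of that pattern in isolation, with names shortened for the sketch:]

    from typing import Optional

    _metadata: Optional[object] = None

    def get_metadata():
        return _metadata

    def set_metadata(value):
        global _metadata
        assert _metadata is None  # setting twice is treated as a programming error
        _metadata = value

    set_metadata({"num_layers": 61})
    assert get_metadata() == {"num_layers": 61}
    # A second set_metadata(...) call would raise AssertionError.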
From a90120a0812fb18d941b9bc81ebdbdb03561c9fc Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:11:41 +0800
Subject: [PATCH 0299/1089] more

---
 python/sglang/srt/model_executor/model_runner.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fe684980f40..af998956271 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -48,8 +48,7 @@
 from sglang.srt.managers.expert_distribution import expert_distribution_recorder
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.schedule_batch import (
-    global_expert_location_metadata,
-    global_server_args_dict,
+    global_server_args_dict, set_global_expert_location_metadata, get_global_expert_location_metadata,
 )
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -175,7 +174,7 @@ def __init__(
                 "use_mla_backend": self.use_mla_backend,
             }
         )
-        global_expert_location_metadata.update(expert_location_metadata)
+        set_global_expert_location_metadata(expert_location_metadata)
 
         # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
@@ -198,7 +197,7 @@ def initialize(self, min_per_gpu_memory: float):
 
         expert_distribution_recorder.initialize(
             server_args,
-            global_expert_location_metadata,
+            get_global_expert_location_metadata(),
             # TODO handle DP!=TP case
             rank=self.tp_rank,
         )

From a4d47fac6698bc7e2b2fcdb5980a301644640485 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:13:30 +0800
Subject: [PATCH 0300/1089] more

---
 python/sglang/srt/managers/expert_location.py | 20 +++++++++++--------
 .../sglang/srt/model_executor/model_runner.py |  5 ++---
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index fbb3d2a5564..e4d36a75882 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -14,6 +14,8 @@ class ExpertLocationMetadata:
     # will have a `logical_to_physical_map` later
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
 
+    # -------------------------------- construction and mutation ------------------------------------
+
     @staticmethod
     def from_model_config(model_config: ModelConfig):
         model_class, _ = get_model_architecture(model_config)
@@ -43,14 +45,6 @@ def init_new(num_layers: int, num_logical_experts: int):
     def _init_dummy():
         return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1)
 
-    def local_physical_to_global_physical(
-            self, rank: int, local_physical_expert_index: int
-    ):
-        return self.num_local_physical_experts * rank + local_physical_expert_index
-
-    def global_physical_to_local_physical(self, global_physical_expert_index: int):
-        return global_physical_expert_index % self.num_local_physical_experts
-
     def update(self, other: "ExpertLocationMetadata"):
         for field in [
             "num_layers",
@@ -65,6 +59,16 @@ def update(self, other: "ExpertLocationMetadata"):
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)
 
+    # -------------------------------- usage ------------------------------------
+
+    def local_physical_to_global_physical(
+            self, rank: int, local_physical_expert_index: int
+    ):
+        return self.num_local_physical_experts * rank + local_physical_expert_index
+
+    def global_physical_to_local_physical(self, global_physical_expert_index: int):
+        return global_physical_expert_index % self.num_local_physical_experts
+
 
 def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int):
     return torch.arange(0, num_physical_experts).repeat(num_layers, 1)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index af998956271..f5328e1b1d1 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -24,7 +24,6 @@
 
 import torch
 import torch.distributed as dist
-
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -177,7 +176,7 @@ def __init__(
         set_global_expert_location_metadata(expert_location_metadata)
 
         # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3))
 
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
@@ -924,7 +923,7 @@ def init_double_sparsity_channel_config(self, selected_channel):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[
-                    :, : self.server_args.ds_heavy_channel_num
+                :, : self.server_args.ds_heavy_channel_num
                 ]
                 .contiguous()
                 .cuda()
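[Note: among the lines touched above, `set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))` converts the user-facing gibibyte setting into bytes; the surrounding patches only toggle spacing around `**`. A worked example with an illustrative value:]

    cpu_offload_gb = 4                        # hypothetical --cpu-offload-gb value
    max_bytes = int(cpu_offload_gb * 1024**3)
    assert max_bytes == 4_294_967_296         # 4 GiB expressed in bytes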
From 6948dab6de110c77de0437324a1e08131eef16c1 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:16:00 +0800
Subject: [PATCH 0301/1089] fmt

---
 python/sglang/srt/managers/expert_location.py |  5 +--
 python/sglang/srt/managers/schedule_batch.py  | 36 +++++++++----------
 .../sglang/srt/model_executor/model_runner.py |  9 +++--
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index e4d36a75882..5972d294a32 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 import torch
+
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.model_loader import get_model_architecture
@@ -27,7 +28,7 @@ def from_model_config(model_config: ModelConfig):
     def init_new(num_layers: int, num_logical_experts: int):
         # TODO handle more complex cases like duplicating experts on different GPUs
         num_local_physical_experts = (
-                num_logical_experts // get_tensor_model_parallel_world_size()
+            num_logical_experts // get_tensor_model_parallel_world_size()
         )
         num_physical_experts = num_logical_experts
 
@@ -62,7 +63,7 @@ def update(self, other: "ExpertLocationMetadata"):
     # -------------------------------- usage ------------------------------------
 
     def local_physical_to_global_physical(
-            self, rank: int, local_physical_expert_index: int
+        self, rank: int, local_physical_expert_index: int
     ):
         return self.num_local_physical_experts * rank + local_physical_expert_index
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8974cf976bd..c1267f597fa 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -620,7 +620,7 @@ def init_incremental_detokenize(self):
         )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        return all_ids[self.surr_offset:], self.read_offset - self.surr_offset
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def check_finished(self):
         if self.finished():
@@ -661,7 +661,7 @@ def check_finished(self):
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1):]
+                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
             )
 
             for stop_str in self.sampling_params.stop_strs:
@@ -957,15 +957,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int])
                 # NOTE: the encoder part should be considered as a whole
                 assert len(req.prefix_indices) == 0
                 input_ids[i] = input_ids[i][encoder_len:]
-                encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len])
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                 decoder_out_cache_loc.append(
-                    self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len]
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                 )
                 self.extend_lens[i] -= encoder_len
                 self.extend_num_tokens -= encoder_len
             else:
                 decoder_out_cache_loc.append(
-                    self.out_cache_loc[pt: pt + req.extend_input_len]
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
                 )
                 self.prefix_lens[i] -= encoder_len
 
@@ -1004,7 +1004,7 @@ def prepare_for_extend(self):
         # Init tensors
         reqs = self.reqs
-        input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
@@ -1079,8 +1079,8 @@ def prepare_for_extend(self):
                     global_start_idx = req.logprob_start_len
 
                 logprob_token_ids = req.origin_input_ids[
-                                    global_start_idx + 1: global_end_idx + 1
-                                    ]
+                    global_start_idx + 1 : global_end_idx + 1
+                ]
                 extend_input_logprob_token_ids.extend(logprob_token_ids)
 
                 # We will need req.extend_input_len - req.extend_logprob_start_len number of
@@ -1154,7 +1154,7 @@ def prepare_for_extend(self):
         for i in range(bs):
             self.req_to_token_pool.write(
                 (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                out_cache_loc[pt: pt + extend_lens[i]],
+                out_cache_loc[pt : pt + extend_lens[i]],
             )
             pt += extend_lens[i]
 
@@ -1271,18 +1271,18 @@ def get_required_tokens(num_reqs: int):
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
-                                req.req_pool_idx, : seq_lens_cpu[idx]
-                                ]
+                    req.req_pool_idx, : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = (
-                                        len(req.prefix_indices) // server_args.page_size
-                                    ) * server_args.page_size
+                    len(req.prefix_indices) // server_args.page_size
+                ) * server_args.page_size
                 token_indices = self.req_to_token_pool.req_to_token[
-                                req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx]
-                                ]
+                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
 
@@ -1306,8 +1306,8 @@ def get_required_tokens(num_reqs: int):
         total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
 
         new_estimate_ratio = (
-                                     total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
-                             ) / total_max_new_tokens
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
         return retracted_reqs, new_estimate_ratio
@@ -1404,7 +1404,7 @@ def filter_batch(
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-               and self.reqs[i] is not chunked_req_to_exclude
+            and self.reqs[i] is not chunked_req_to_exclude
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f5328e1b1d1..a88c13094fa 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -24,6 +24,7 @@
 
 import torch
 import torch.distributed as dist
+
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -47,7 +48,9 @@
 from sglang.srt.managers.expert_distribution import expert_distribution_recorder
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.schedule_batch import (
-    global_server_args_dict, set_global_expert_location_metadata, get_global_expert_location_metadata,
+    get_global_expert_location_metadata,
+    global_server_args_dict,
+    set_global_expert_location_metadata,
 )
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -176,7 +179,7 @@ def __init__(
         set_global_expert_location_metadata(expert_location_metadata)
 
         # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3))
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
@@ -923,7 +926,7 @@ def init_double_sparsity_channel_config(self, selected_channel):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[
-                :, : self.server_args.ds_heavy_channel_num
+                    :, : self.server_args.ds_heavy_channel_num
                 ]
                 .contiguous()
                 .cuda()

From 351bf69cc238c86b52c505bde97570b8bbb35932 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:16:33 +0800
Subject: [PATCH 0302/1089] more

---
 python/sglang/srt/managers/expert_location.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 5972d294a32..4afe1eaa820 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,7 +1,6 @@
 from dataclasses import dataclass
 
 import torch
-
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.model_loader import get_model_architecture
@@ -12,8 +11,9 @@ class ExpertLocationMetadata:
     num_layers: int
     num_local_physical_experts: int
     num_logical_experts: int
-    # will have a `logical_to_physical_map` later
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
+    logical_to_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
+    chosen_logical_to_physical_map: torch.Tensor  # (layers, num_logical_experts)
 
     # -------------------------------- construction and mutation ------------------------------------
 
@@ -28,7 +28,7 @@ def from_model_config(model_config: ModelConfig):
     def init_new(num_layers: int, num_logical_experts: int):
         # TODO handle more complex cases like duplicating experts on different GPUs
         num_local_physical_experts = (
-            num_logical_experts // get_tensor_model_parallel_world_size()
+                num_logical_experts // get_tensor_model_parallel_world_size()
         )
         num_physical_experts = num_logical_experts
 
@@ -63,13 +63,16 @@ def update(self, other: "ExpertLocationMetadata"):
     # -------------------------------- usage ------------------------------------
 
     def local_physical_to_global_physical(
-        self, rank: int, local_physical_expert_index: int
+            self, rank: int, local_physical_expert_index: int
     ):
         return self.num_local_physical_experts * rank + local_physical_expert_index
 
     def global_physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
 
+    def logical_to_global_physical(self, logical_expert_id: int):
+        return [logical_expert_id]  # TODO add a logical_to_physical_map
+
 
 def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int):
     return torch.arange(0, num_physical_experts).repeat(num_layers, 1)
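[Note: patch 0302 introduces `logical_to_physical_map` with shape (layers, num_logical_experts, X); the follow-up patch below fills unused replica slots with -1. Assuming X is the maximum number of physical replicas a logical expert may have, which is an inference from the shapes rather than something the patches state, a toy map looks like this:]

    import torch

    # 1 layer, 2 logical experts, up to X = 2 physical copies each.
    # Logical expert 0 lives on physical slots 0 and 2; expert 1 has a single
    # copy on slot 1, with -1 padding the unused replica slot.
    logical_to_physical_map = torch.tensor([[[0, 2],
                                             [1, -1]]])
    assert logical_to_physical_map.shape == (1, 2, 2)  # (layers, logical, X)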
From dfca6932f8792cf7b6453792f70868b24815e2a8 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:17:56 +0800
Subject: [PATCH 0303/1089] more

---
 python/sglang/srt/managers/expert_location.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 4afe1eaa820..1ae23de4df9 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import List
 
 import torch
 from sglang.srt.configs.model_config import ModelConfig
@@ -70,8 +71,12 @@ def local_physical_to_global_physical(
     def global_physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
 
-    def logical_to_global_physical(self, logical_expert_id: int):
-        return [logical_expert_id]  # TODO add a logical_to_physical_map
+    def logical_to_global_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
+        return [
+            physical_expert_id
+            for physical_expert_id in self.logical_to_physical_map[layer_id, logical_expert_id].tolist()
+            if physical_expert_id != -1
+        ]

From e5e955e08e66666f9294475a9cd6b7fbbdb530026 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:18:29 +0800
Subject: [PATCH 0304/1089] more

---
 python/sglang/srt/managers/expert_location.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 1ae23de4df9..064cbc24f78 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -13,8 +13,8 @@ class ExpertLocationMetadata:
     num_local_physical_experts: int
     num_logical_experts: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
-    logical_to_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
-    chosen_logical_to_physical_map: torch.Tensor  # (layers, num_logical_experts)
+    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
+    logical_to_chosen_physical_map: torch.Tensor  # (layers, num_logical_experts)
 
     # -------------------------------- construction and mutation ------------------------------------
 
@@ -71,10 +71,10 @@ def local_physical_to_global_physical(
     def global_physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
 
-    def logical_to_global_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
+    def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
         return [
             physical_expert_id
-            for physical_expert_id in self.logical_to_physical_map[layer_id, logical_expert_id].tolist()
+            for physical_expert_id in self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
             if physical_expert_id != -1
         ]

From cae8461d6e813df5ab7fa925a0a8daa3992e614a Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:18:48 +0800
Subject: [PATCH 0305/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 9 ++++-----
 python/sglang/srt/managers/expert_location.py     | 4 ++--
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 28b71f059de..b3ab4227f07 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Type
 
 import torch
-
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_bool_env_var
@@ -188,12 +187,12 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         torch.cuda.synchronize()
 
         num_recv_tokens_per_expert_list = [
-            0
-        ] * self._expert_location_metadata.num_local_physical_experts
+                                              0
+                                          ] * self._expert_location_metadata.num_local_physical_experts
         for token_record in topk_ids_list:
             for global_physical_expert_idx in token_record:
                 local_physical_expert_idx = (
-                    self._expert_location_metadata.global_physical_to_local_physical(
+                    self._expert_location_metadata.physical_to_local_physical(
                         global_physical_expert_idx
                     )
                 )
@@ -334,7 +333,7 @@ def postprocess_dumps(
             expert_location_metadata.num_local_physical_experts
         ):
             global_physical_expert_index = (
-                expert_location_metadata.local_physical_to_global_physical(
+                expert_location_metadata.local_physical_to_physical(
                     rank=physical_dump["rank"],
                     local_physical_expert_index=local_physical_expert_index,
                 )
diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 064cbc24f78..171cc2183e7 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -63,12 +63,12 @@ def update(self, other: "ExpertLocationMetadata"):
 
     # -------------------------------- usage ------------------------------------
 
-    def local_physical_to_global_physical(
+    def local_physical_to_physical(
             self, rank: int, local_physical_expert_index: int
     ):
         return self.num_local_physical_experts * rank + local_physical_expert_index
 
-    def global_physical_to_local_physical(self, global_physical_expert_index: int):
+    def physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
 
     def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:

From d38fda2c69a885aaa8f61ecf0880a1c8bc6fb506 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:22:16 +0800
Subject: [PATCH 0306/1089] more

---
 python/sglang/srt/managers/expert_location.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 171cc2183e7..150f4e66fe3 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -14,7 +14,7 @@ class ExpertLocationMetadata:
     num_logical_experts: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
-    logical_to_chosen_physical_map: torch.Tensor  # (layers, num_logical_experts)
+    logical_to_rank_chosen_physical_map: torch.Tensor  # (num_gpus, layers, num_logical_experts)

From 1f04f7186736e845e809de6529695affacd46ae8 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:22:30 +0800
Subject: [PATCH 0307/1089] more

---
 python/sglang/srt/managers/expert_location.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 150f4e66fe3..2f168e64a72 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -57,6 +57,8 @@ def update(self, other: "ExpertLocationMetadata"):
 
         for field in [
             "physical_to_logical_map",
+            "logical_to_all_physical_map",
+            "logical_to_rank_chosen_physical_map",
         ]:
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)
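[Note: the renamed helpers `local_physical_to_physical` and `physical_to_local_physical` are plain block-layout index arithmetic: rank r owns the contiguous slice [r * n_local, (r + 1) * n_local). A worked example with illustrative sizes:]

    num_local_physical_experts = 8   # e.g. 64 physical experts spread over 8 ranks

    def local_physical_to_physical(rank, local_physical_expert_index):
        return num_local_physical_experts * rank + local_physical_expert_index

    def physical_to_local_physical(global_physical_expert_index):
        return global_physical_expert_index % num_local_physical_experts

    g = local_physical_to_physical(rank=3, local_physical_expert_index=5)
    assert g == 29
    assert physical_to_local_physical(g) == 5    # back to the local slot
    assert g // num_local_physical_experts == 3  # the owning rank is recoverable too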
From ec131242b0e210d4be025b93d66740164eaeb367 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:22:49 +0800
Subject: [PATCH 0308/1089] more

---
 python/sglang/srt/managers/expert_location.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 2f168e64a72..6aa73861d32 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -41,6 +41,8 @@ def init_new(num_layers: int, num_logical_experts: int):
                 num_layers=num_layers,
                 num_physical_experts=num_physical_experts,
             ),
+            logical_to_all_physical_map=TODO,
+            logical_to_rank_chosen_physical_map=TODO,
         )
 
     @staticmethod

From a9cb3fa5776cf250a3e15932827cb920d4b1370b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:25:40 +0800
Subject: [PATCH 0309/1089] more

---
 python/sglang/srt/managers/expert_location.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 6aa73861d32..97c2f130fce 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -37,10 +37,7 @@ def init_new(num_layers: int, num_logical_experts: int):
             num_layers=num_layers,
             num_logical_experts=num_logical_experts,
             num_local_physical_experts=num_local_physical_experts,
-            physical_to_logical_map=_create_vanilla_physical_to_logical_map(
-                num_layers=num_layers,
-                num_physical_experts=num_physical_experts,
-            ),
+            physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1),
             logical_to_all_physical_map=TODO,
             logical_to_rank_chosen_physical_map=TODO,
         )
@@ -81,7 +78,3 @@ def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List
             for physical_expert_id in self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
             if physical_expert_id != -1
         ]
-
-
-def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int):
-    return torch.arange(0, num_physical_experts).repeat(num_layers, 1)

From 66a6702443e24aaca7946dc895bd35fa9565d01e Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:28:02 +0800
Subject: [PATCH 0310/1089] more

---
 python/sglang/srt/managers/expert_location.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 97c2f130fce..f4268855676 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -38,8 +38,8 @@ def init_new(num_layers: int, num_logical_experts: int):
             num_logical_experts=num_logical_experts,
             num_local_physical_experts=num_local_physical_experts,
             physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1),
-            logical_to_all_physical_map=TODO,
-            logical_to_rank_chosen_physical_map=TODO,
+            logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1)[..., None],
+            logical_to_rank_chosen_physical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1)[..., None],
         )
 
     @staticmethod

From fa0ee764e1ea839199e5f9ec1bb7de0f4ce25924 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:48:11 +0800
Subject: [PATCH 0311/1089] rm

---
 python/sglang/srt/managers/expert_location.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index f4268855676..4631adc620d 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import List
 
 import torch
 from sglang.srt.configs.model_config import ModelConfig
@@ -13,8 +12,6 @@ class ExpertLocationMetadata:
     num_local_physical_experts: int
     num_logical_experts: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
-    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
-    logical_to_rank_chosen_physical_map: torch.Tensor  # (num_gpus, layers, num_logical_experts)
 
     # -------------------------------- construction and mutation ------------------------------------
 
@@ -38,8 +35,6 @@ def init_new(num_layers: int, num_logical_experts: int):
             num_logical_experts=num_logical_experts,
             num_local_physical_experts=num_local_physical_experts,
             physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1),
-            logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1)[..., None],
-            logical_to_rank_chosen_physical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1)[..., None],
         )
 
     @staticmethod
@@ -56,8 +51,6 @@ def update(self, other: "ExpertLocationMetadata"):
 
         for field in [
             "physical_to_logical_map",
-            "logical_to_all_physical_map",
-            "logical_to_rank_chosen_physical_map",
         ]:
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)
@@ -71,10 +64,3 @@ def local_physical_to_physical(
 
     def physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts
-
-    def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
-        return [
-            physical_expert_id
-            for physical_expert_id in self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
-            if physical_expert_id != -1
-        ]
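[Note: the `on_select_experts` loop touched in patches 0305 and 0312 tallies, per step, how many routed tokens each local physical expert receives by folding global expert ids down with `physical_to_local_physical`. A self-contained rendering of that counting logic, using toy ids that are not part of the patches:]

    import torch

    num_local_physical_experts = 4
    topk_ids = torch.tensor([[0, 5], [4, 7]])  # top-k expert ids for two tokens

    num_recv_tokens_per_expert_list = [0] * num_local_physical_experts
    for token_record in topk_ids.tolist():
        for global_physical_expert_idx in token_record:
            local_idx = global_physical_expert_idx % num_local_physical_experts
            num_recv_tokens_per_expert_list[local_idx] += 1

    assert num_recv_tokens_per_expert_list == [2, 1, 0, 1]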
From 52e820ccaff0e4a70bab61701f2267d0117d3a71 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:49:10 +0800
Subject: [PATCH 0312/1089] rm

---
 .../srt/managers/expert_distribution.py       |  5 ++--
 python/sglang/srt/managers/expert_location.py | 25 +++++--------------
 2 files changed, 9 insertions(+), 21 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index b3ab4227f07..fb31d415d76 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -5,6 +5,7 @@
 from typing import Any, List, Optional, Type
 
 import torch
+
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_bool_env_var
@@ -187,8 +188,8 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         torch.cuda.synchronize()
 
         num_recv_tokens_per_expert_list = [
-                                              0
-                                          ] * self._expert_location_metadata.num_local_physical_experts
+            0
+        ] * self._expert_location_metadata.num_local_physical_experts
         for token_record in topk_ids_list:
             for global_physical_expert_idx in token_record:
                 local_physical_expert_idx = (
diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 4631adc620d..50012ba583b 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 import torch
+
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.model_loader import get_model_architecture
@@ -26,7 +27,7 @@ def from_model_config(model_config: ModelConfig):
     def init_new(num_layers: int, num_logical_experts: int):
         # TODO handle more complex cases like duplicating experts on different GPUs
         num_local_physical_experts = (
-                num_logical_experts // get_tensor_model_parallel_world_size()
+            num_logical_experts // get_tensor_model_parallel_world_size()
         )
         num_physical_experts = num_logical_experts
 
@@ -34,32 +35,18 @@ def init_new(num_layers: int, num_logical_experts: int):
             num_layers=num_layers,
             num_logical_experts=num_logical_experts,
             num_local_physical_experts=num_local_physical_experts,
-            physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(num_layers, 1),
+            physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(
+                num_layers, 1
+            ),
         )
 
     @staticmethod
     def _init_dummy():
         return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1)
 
-    def update(self, other: "ExpertLocationMetadata"):
-        for field in [
-            "num_layers",
-            "num_local_physical_experts",
-            "num_logical_experts",
-        ]:
-            assert getattr(self, field) == getattr(other, field)
-
-        for field in [
-            "physical_to_logical_map",
-        ]:
-            # Cannot update address to avoid breaking CUDA graph
-            getattr(self, field)[...] = getattr(other, field)
-
     # -------------------------------- usage ------------------------------------
 
-    def local_physical_to_physical(
-            self, rank: int, local_physical_expert_index: int
-    ):
+    def local_physical_to_physical(self, rank: int, local_physical_expert_index: int):
         return self.num_local_physical_experts * rank + local_physical_expert_index
 
     def physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts

From 5767e79c951a0f70e56811c6d7cd474f646cb11a Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:50:06 +0800
Subject: [PATCH 0313/1089] cherry pick

---
 python/sglang/srt/managers/expert_location.py | 103 ++++++++----------
 1 file changed, 46 insertions(+), 57 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index bddfcaa3bdd..a00ce8ffbf6 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import List
 
 import torch
 
@@ -9,94 +10,82 @@
 
 @dataclass
 class ExpertLocationMetadata:
-    is_dummy: bool
     num_layers: int
     num_local_physical_experts: int
     num_logical_experts: int
-    # will have a `logical_to_physical_map` later
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
-    chosen_logical_to_physical_map: torch.Tensor  # (layers, num_logical_experts)
+    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
+    logical_to_rank_chosen_physical_map: (
+        torch.Tensor
+    )  # (num_gpus, layers, num_logical_experts)
+
+    # -------------------------------- construction and mutation ------------------------------------
 
     @staticmethod
     def from_model_config(model_config: ModelConfig):
         model_class, _ = get_model_architecture(model_config)
         if hasattr(model_class, "get_expert_location_metadata"):
             return model_class.get_expert_location_metadata(model_config.hf_config)
-        return ExpertLocationMetadata.init_dummy()
+        return ExpertLocationMetadata._init_dummy()
 
     @staticmethod
-    def init_new(num_layers: int, num_logical_experts: int, is_dummy: bool = False):
+    def init_new(num_layers: int, num_logical_experts: int):
         # TODO handle more complex cases like duplicating experts on different GPUs
         num_local_physical_experts = (
             num_logical_experts // get_tensor_model_parallel_world_size()
         )
         num_physical_experts = num_logical_experts
 
-        output = ExpertLocationMetadata(
-            is_dummy=is_dummy,
+        return ExpertLocationMetadata(
             num_layers=num_layers,
             num_logical_experts=num_logical_experts,
             num_local_physical_experts=num_local_physical_experts,
-            physical_to_logical_map=_create_vanilla_physical_to_logical_map(
-                num_layers=num_layers,
-                num_physical_experts=num_physical_experts,
+            physical_to_logical_map=torch.arange(0, num_physical_experts).repeat(
+                num_layers, 1
             ),
-            chosen_logical_to_physical_map=TODO,
+            logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat(
+                num_layers, 1
+            )[..., None],
+            logical_to_rank_chosen_physical_map=torch.arange(
+                0, num_physical_experts
+            ).repeat(num_layers, 1)[..., None],
         )
-        output._rebuild()
-        return output
 
     @staticmethod
-    def init_dummy():
-        return ExpertLocationMetadata.init_new(
-            num_layers=1, num_logical_experts=1, is_dummy=True
-        )
-
-    def local_physical_to_global_physical(
-        self, rank: int, local_physical_expert_index: int
-    ):
-        return self.num_local_physical_experts * rank + local_physical_expert_index
-
-    def global_physical_to_local_physical(self, global_physical_expert_index: int):
-        return global_physical_expert_index % self.num_local_physical_experts
-
-    def logical_to_global_physical(self, logical_expert_id: int):
-        return [logical_expert_id]  # TODO add a logical_to_physical_map
-
-    def _rebuild(self):
-        self.chosen_logical_to_physical_map[...] = TODO
+    def _init_dummy():
+        return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1)
 
     def update(self, other: "ExpertLocationMetadata"):
-        if self.is_dummy:
-            self._update_unconditionally(other)
-        else:
-            self._update_partial(other)
-        self._rebuild()
-
-    def _update_unconditionally(self, other: "ExpertLocationMetadata"):
-        for field in _UPDATE_FIELDS_TRIVIAL:
-            setattr(self, field, getattr(other, field))
-        for field in _UPDATE_FIELDS_TENSOR:
-            setattr(self, field, getattr(other, field).detach().clone())
-
-    def _update_partial(self, other: "ExpertLocationMetadata"):
+        for field in [
+            "num_layers",
+            "num_local_physical_experts",
+            "num_logical_experts",
+        ]:
             assert getattr(self, field) == getattr(other, field)
-        for field in _UPDATE_FIELDS_TENSOR:
+
+        for field in [
+            "physical_to_logical_map",
+            "logical_to_all_physical_map",
+            "logical_to_rank_chosen_physical_map",
+        ]:
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)
 
+    # -------------------------------- usage ------------------------------------
 
-_UPDATE_FIELDS_TRIVIAL = [
-    "is_dummy",
-    "num_layers",
-    "num_local_physical_experts",
-    "num_logical_experts",
-]
-_UPDATE_FIELDS_TENSOR = [
-    "physical_to_logical_map",
-]
+    def local_physical_to_physical(self, rank: int, local_physical_expert_index: int):
+        return self.num_local_physical_experts * rank + local_physical_expert_index
 
+    def physical_to_local_physical(self, global_physical_expert_index: int):
+        return global_physical_expert_index % self.num_local_physical_experts
 
-def _create_vanilla_physical_to_logical_map(num_layers: int, num_physical_experts: int):
-    return torch.arange(0, num_physical_experts).repeat(num_layers, 1)
+    def logical_to_all_physical(
+        self, layer_id: int, logical_expert_id: int
+    ) -> List[int]:
+        return [
+            physical_expert_id
+            for physical_expert_id in self.logical_to_all_physical_map[
+                layer_id, logical_expert_id
+            ].tolist()
+            if physical_expert_id != -1
+        ]

From f7be53288661a8481597cce15dd36095671fe177 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:51:21 +0800
Subject: [PATCH 0314/1089] rm

---
 python/sglang/srt/managers/expert_location.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index a00ce8ffbf6..a6e565c4107 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -55,22 +55,6 @@ def init_new(num_layers: int, num_logical_experts: int):
     def _init_dummy():
         return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1)
 
-    def update(self, other: "ExpertLocationMetadata"):
-        for field in [
-            "num_layers",
-            "num_local_physical_experts",
-            "num_logical_experts",
-        ]:
-            assert getattr(self, field) == getattr(other, field)
-
-        for field in [
-            "physical_to_logical_map",
-            "logical_to_all_physical_map",
-            "logical_to_rank_chosen_physical_map",
-        ]:
-            # Cannot update address to avoid breaking CUDA graph
-            getattr(self, field)[...] = getattr(other, field)
-
     # -------------------------------- usage ------------------------------------
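[Note: the `init_new` re-added by the cherry-pick above builds "vanilla" placement tensors: physical expert i hosts logical expert i in every layer, and the `[..., None]` unsqueeze gives the all-physical map a replica axis of size X = 1. Concretely, with small illustrative sizes:]

    import torch

    num_layers, num_physical_experts = 2, 4

    physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1)
    # tensor([[0, 1, 2, 3],
    #         [0, 1, 2, 3]])

    logical_to_all_physical_map = physical_to_logical_map[..., None]
    assert logical_to_all_physical_map.shape == (2, 4, 1)  # (layers, logical, X)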
From 8767194beca7d6b185f87ab2bdc22206116ca4bb Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:55:31 +0800
Subject: [PATCH 0315/1089] more

---
 python/sglang/srt/server_args.py | 105 ++++++++++++++++---------------
 1 file changed, 56 insertions(+), 49 deletions(-)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 28539dcee44..09d9f3499e6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -162,6 +162,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    ep_num_redundant_experts: int = 0
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -298,7 +299,7 @@ def __post_init__(self):
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             assert (
-                self.dp_size > 1
+                    self.dp_size > 1
             ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
             assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
@@ -327,8 +328,8 @@ def __post_init__(self):
             self.speculative_algorithm = "EAGLE"
 
         if (
-            self.speculative_algorithm == "EAGLE"
-            or self.speculative_algorithm == "EAGLE3"
+                self.speculative_algorithm == "EAGLE"
+                or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
                 self.max_running_requests = 48
@@ -341,8 +342,8 @@ def __post_init__(self):
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (
-                    self.speculative_eagle_topk is None
-                    and self.speculative_num_draft_tokens is None
+                        self.speculative_eagle_topk is None
+                        and self.speculative_num_draft_tokens is None
                 )
                 (
                     self.speculative_num_steps,
@@ -360,7 +361,7 @@ def __post_init__(self):
 
         # GGUF
         if (
-            self.load_format == "auto" or self.load_format == "gguf"
+                self.load_format == "auto" or self.load_format == "gguf"
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
@@ -414,8 +415,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.tokenizer_mode,
             choices=["auto", "slow"],
             help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
+                 "tokenizer if available, and 'slow' will "
+                 "always use the slow tokenizer.",
         )
         parser.add_argument(
             "--skip-tokenizer-init",
@@ -439,21 +440,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
                 "remote",
             ],
             help="The format of the model weights to load. "
-            '"auto" will try to load the weights in the safetensors format '
-            "and fall back to the pytorch bin format if safetensors format "
-            "is not available. "
-            '"pt" will load the weights in the pytorch bin format. '
-            '"safetensors" will load the weights in the safetensors format. '
-            '"npcache" will load the weights in pytorch format and store '
-            "a numpy cache to speed up the loading. "
-            '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling."
-            '"gguf" will load the weights in the gguf format. '
-            '"bitsandbytes" will load the weights using bitsandbytes '
-            "quantization."
-            '"layered" loads weights layer by layer so that one can quantize a '
-            "layer before loading another to make the peak memory envelope "
-            "smaller.",
+                 '"auto" will try to load the weights in the safetensors format '
+                 "and fall back to the pytorch bin format if safetensors format "
+                 "is not available. "
+                 '"pt" will load the weights in the pytorch bin format. '
+                 '"safetensors" will load the weights in the safetensors format. '
+                 '"npcache" will load the weights in pytorch format and store '
+                 "a numpy cache to speed up the loading. "
+                 '"dummy" will initialize the weights with random values, '
+                 "which is mainly for profiling."
+                 '"gguf" will load the weights in the gguf format. '
+                 '"bitsandbytes" will load the weights using bitsandbytes '
+                 "quantization."
+                 '"layered" loads weights layer by layer so that one can quantize a '
+                 "layer before loading another to make the peak memory envelope "
+                 "smaller.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -466,13 +467,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.dtype,
             choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
             help="Data type for model weights and activations.\n\n"
-            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            "BF16 precision for BF16 models.\n"
-            '* "half" for FP16. Recommended for AWQ quantization.\n'
-            '* "float16" is the same as "half".\n'
-            '* "bfloat16" for a balance between precision and range.\n'
-            '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.',
+                 '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+                 "BF16 precision for BF16 models.\n"
+                 '* "half" for FP16. Recommended for AWQ quantization.\n'
+                 '* "float16" is the same as "half".\n'
+                 '* "bfloat16" for a balance between precision and range.\n'
+                 '* "float" is shorthand for FP32 precision.\n'
+                 '* "float32" for FP32 precision.',
         )
         parser.add_argument(
             "--kv-cache-dtype",
@@ -507,9 +508,9 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=nullable_str,
             default=None,
             help="Path to the JSON file containing the KV cache "
-            "scaling factors. This should generally be supplied, when "
-            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
-            "default to 1.0, which may cause accuracy issues. ",
+                 "scaling factors. This should generally be supplied, when "
+                 "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+                 "default to 1.0, which may cause accuracy issues. ",
        )
         parser.add_argument(
             "--context-length",
@@ -551,8 +552,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=str,
             default=None,
             help="The specific model version to use. It can be a branch "
-            "name, a tag name, or a commit id. If unspecified, will use "
-            "the default version.",
+                 "name, a tag name, or a commit id. If unspecified, will use "
+                 "the default version.",
         )
 
         # Memory and scheduling
@@ -572,7 +573,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=int,
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
-            "This option is typically used for development and debugging purposes.",
+                 "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -1037,7 +1038,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels.",
+                 "This only affects Triton attention kernels.",
        )
         parser.add_argument(
             "--triton-attention-num-kv-splits",
@@ -1050,8 +1051,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=int,
             default=ServerArgs.num_continuous_decode_steps,
             help="Run multiple continuous decoding steps to reduce scheduling overhead. "
-            "This can potentially increase throughput but may also increase time-to-first-token latency. "
-            "The default value is 1, meaning only run one decoding step at a time.",
+                 "This can potentially increase throughput but may also increase time-to-first-token latency. "
+                 "The default value is 1, meaning only run one decoding step at a time.",
         )
         parser.add_argument(
             "--delete-ckpt-after-loading",
@@ -1097,6 +1098,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--ep-num-redundant-experts",
+            type=int,
+            default=ServerArgs.ep_num_redundant_experts,
+            help="Allocate this number of redundant experts in expert parallel.",
+        )
         parser.add_argument(
             "--deepep-mode",
             type=str,
@@ -1110,7 +1117,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=int,
             default=0,
             help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
-            "we use tp_size by default.",
+                 "we use tp_size by default.",
         )
         parser.add_argument(
             "--disable-shared-experts-fusion",
@@ -1124,7 +1131,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             type=str,
             required=False,
             help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
-            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
+                 "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
         )
 
         # Debug tensor dumps
@@ -1178,16 +1185,16 @@ def url(self):
 
     def check_server_args(self):
         assert (
-            self.tp_size % self.nnodes == 0
+                self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+                self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
         assert (
-            self.max_loras_per_batch > 0
-            # FIXME
-            and (self.lora_paths is None or self.disable_cuda_graph)
-            and (self.lora_paths is None or self.disable_radix_cache)
+                self.max_loras_per_batch > 0
+                # FIXME
+                and (self.lora_paths is None or self.disable_cuda_graph)
+                and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
@@ -1270,14 +1277,14 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
 
             dist_init_addr = server_args.dist_init_addr.split(":")
             assert (
-                len(dist_init_addr) == 2
+                    len(dist_init_addr) == 2
             ), "please provide --dist-init-addr as host:port of head node"
 
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
                 scheduler_input_port = (
-                    port_base + 3
+                        port_base + 3
                 )  # TokenizerManager to DataParallelController
             else:
                 scheduler_input_port = port_base + 3 + 1 + dp_rank
From afe433dc316516977b8282f2048ec021eb145f36 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:55:57 +0800
Subject: [PATCH 0316/1089] more

---
 python/sglang/srt/server_args.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 09d9f3499e6..6c9a7813746 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -299,7 +299,7 @@ def __post_init__(self):
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             assert (
-                    self.dp_size > 1
+                self.dp_size > 1
             ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
             assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
@@ -322,14 +322,17 @@ def __post_init__(self):
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        if self.ep_num_redundant_experts > 0:
+            assert self.enable_deepep_moe, "ep_num_redundant_experts currently requires DeepEP MoE"
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"
 
         if (
-                self.speculative_algorithm == "EAGLE"
-                or self.speculative_algorithm == "EAGLE3"
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
                 self.max_running_requests = 48
@@ -341,8 +345,8 @@ def __post_init__(self):
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (
-                        self.speculative_eagle_topk is None
-                        and self.speculative_num_draft_tokens is None
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
                 )
                 (
                     self.speculative_num_steps,
@@ -360,7 +364,7 @@ def __post_init__(self):
 
         # GGUF
         if (
-                self.load_format == "auto" or self.load_format == "gguf"
+            self.load_format == "auto" or self.load_format == "gguf"
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
@@ -1185,16 +1188,16 @@ def url(self):
 
     def check_server_args(self):
         assert (
-                self.tp_size % self.nnodes == 0
+            self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-                self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
         assert (
-                self.max_loras_per_batch > 0
-                # FIXME
-                and (self.lora_paths is None or self.disable_cuda_graph)
-                and (self.lora_paths is None or self.disable_radix_cache)
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_cuda_graph)
+            and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
@@ -1277,14 +1280,14 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
 
             dist_init_addr = server_args.dist_init_addr.split(":")
             assert (
-                    len(dist_init_addr) == 2
+                len(dist_init_addr) == 2
             ), "please provide --dist-init-addr as host:port of head node"
 
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
                 scheduler_input_port = (
-                        port_base + 3
+                    port_base + 3
                 )  # TokenizerManager to DataParallelController
             else:
                 scheduler_input_port = port_base + 3 + 1 + dp_rank

From f8b2c17811ad3690554fb0a30e6999bc67c82a8d Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:56:14 +0800
Subject: [PATCH 0317/1089] more

---
 python/sglang/srt/models/deepseek_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2380031fc73..00d8b22ae55 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -201,7 +201,7 @@ def __init__(
         )
 
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion + TODO,
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,

From 2006ba4ec206a0736df028c8fa5b7be9a0688e75 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 16:56:29 +0800
Subject: [PATCH 0318/1089] more

---
 python/sglang/srt/models/deepseek_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 00d8b22ae55..554fa83daf9 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -201,7 +201,7 @@ def __init__(
         )
 
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion + TODO,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
ServerArgs.disable_shared_experts_fusion, + "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts, } _global_expert_location_metadata: Optional[ExpertLocationMetadata] = None @@ -270,17 +271,17 @@ def hash_feature(f): def is_audio(self): return ( - self.modality == Modality.AUDIO + self.modality == Modality.AUDIO ) and not MultimodalDataItem.is_empty_list(self.audio_features) def is_image(self): return ( - self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES + self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES ) and not MultimodalDataItem.is_empty_list(self.pixel_values) def is_video(self): return ( - self.modality == Modality.VIDEO + self.modality == Modality.VIDEO ) and not MultimodalDataItem.is_empty_list(self.pixel_values) def validate(self): @@ -389,24 +390,24 @@ class Req: """The input and output status of a request.""" def __init__( - self, - rid: str, - origin_input_text: str, - origin_input_ids: Tuple[int], - sampling_params: SamplingParams, - return_logprob: bool = False, - top_logprobs_num: int = 0, - token_ids_logprob: List[int] = None, - stream: bool = False, - origin_input_ids_unpadded: Optional[Tuple[int]] = None, - lora_path: Optional[str] = None, - input_embeds: Optional[List[List[float]]] = None, - session_id: Optional[str] = None, - custom_logit_processor: Optional[str] = None, - return_hidden_states: bool = False, - eos_token_ids: Optional[Set[int]] = None, - bootstrap_host: Optional[str] = None, - bootstrap_room: Optional[int] = None, + self, + rid: str, + origin_input_text: str, + origin_input_ids: Tuple[int], + sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + token_ids_logprob: List[int] = None, + stream: bool = False, + origin_input_ids_unpadded: Optional[Tuple[int]] = None, + lora_path: Optional[str] = None, + input_embeds: Optional[List[List[float]]] = None, + session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, + eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_room: Optional[int] = None, ): # Input and output info self.rid = rid @@ -572,9 +573,9 @@ def finished(self) -> bool: return self.finished_reason is not None def init_next_round_input( - self, - tree_cache: Optional[BasePrefixCache] = None, - enable_hierarchical_cache=False, + self, + tree_cache: Optional[BasePrefixCache] = None, + enable_hierarchical_cache=False, ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: @@ -620,7 +621,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset :], self.read_offset - self.surr_offset + return all_ids[self.surr_offset:], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -652,7 +653,7 @@ def check_finished(self): matched_eos |= last_token_id == self.tokenizer.eos_token_id if self.tokenizer.additional_stop_token_ids: matched_eos |= ( - last_token_id in self.tokenizer.additional_stop_token_ids + last_token_id in self.tokenizer.additional_stop_token_ids ) if matched_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) @@ -661,7 +662,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] ) for 
stop_str in self.sampling_params.stop_strs: @@ -777,15 +778,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): @classmethod def init_new( - cls, - reqs: List[Req], - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, - tree_cache: BasePrefixCache, - model_config: ModelConfig, - enable_overlap: bool, - spec_algorithm: SpeculativeAlgorithm, - enable_custom_logit_processor: bool, + cls, + reqs: List[Req], + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, + tree_cache: BasePrefixCache, + model_config: ModelConfig, + enable_overlap: bool, + spec_algorithm: SpeculativeAlgorithm, + enable_custom_logit_processor: bool, ): return_logprob = any(req.return_logprob for req in reqs) @@ -849,17 +850,17 @@ def alloc_token_slots(self, num_tokens: int, backup_state: bool = False): return out_cache_loc def alloc_paged_token_slots_extend( - self, - prefix_lens: torch.Tensor, - seq_lens: torch.Tensor, - last_loc: torch.Tensor, - extend_num_tokens: int, - backup_state: bool = False, + self, + prefix_lens: torch.Tensor, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + backup_state: bool = False, ): if ( - self.token_to_kv_pool_allocator.available_size() - < extend_num_tokens - + len(seq_lens) * self.token_to_kv_pool_allocator.page_size + self.token_to_kv_pool_allocator.available_size() + < extend_num_tokens + + len(seq_lens) * self.token_to_kv_pool_allocator.page_size ): if self.tree_cache is not None: self.tree_cache.evict( @@ -890,15 +891,15 @@ def alloc_paged_token_slots_extend( return out_cache_loc def alloc_paged_token_slots_decode( - self, - seq_lens: torch.Tensor, - last_loc: torch.Tensor, - backup_state: bool = False, + self, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + backup_state: bool = False, ): if self.tree_cache is not None: if ( - self.token_to_kv_pool_allocator.available_size() - < len(seq_lens) * self.token_to_kv_pool_allocator.page_size + self.token_to_kv_pool_allocator.available_size() + < len(seq_lens) * self.token_to_kv_pool_allocator.page_size ): self.tree_cache.evict( len(seq_lens) * self.token_to_kv_pool_allocator.page_size, @@ -957,15 +958,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt : pt + req.extend_input_len] + self.out_cache_loc[pt: pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -1004,7 +1005,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1079,8 +1080,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1 : global_end_idx + 1 - ] + 
global_start_idx + 1: global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1088,9 +1089,9 @@ def prepare_for_extend(self): extend_input_logprob_token_ids.extend( [0] * ( - req.extend_input_len - - req.extend_logprob_start_len - - len(logprob_token_ids) + req.extend_input_len + - req.extend_logprob_start_len + - len(logprob_token_ids) ) ) @@ -1154,7 +1155,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], + out_cache_loc[pt: pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1205,9 +1206,9 @@ def new_page_count_next_decode(self): def check_decode_mem(self, buf_multiplier=1): tokens_required = ( - self.new_page_count_next_decode() - * buf_multiplier - * self.token_to_kv_pool_allocator.page_size + self.new_page_count_next_decode() + * buf_multiplier + * self.token_to_kv_pool_allocator.page_size ) if self.token_to_kv_pool_allocator.available_size() >= tokens_required: @@ -1239,27 +1240,27 @@ def get_required_tokens(num_reqs: int): headroom_for_spec_decode = 0 if server_args.speculative_algorithm: headroom_for_spec_decode += ( - num_reqs - * server_args.speculative_eagle_topk - * server_args.speculative_num_steps - + num_reqs * server_args.speculative_num_draft_tokens + num_reqs + * server_args.speculative_eagle_topk + * server_args.speculative_num_steps + + num_reqs * server_args.speculative_num_draft_tokens ) return ( - num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode + num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode ) retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while ( - self.token_to_kv_pool_allocator.available_size() - < get_required_tokens(len(sorted_indices)) - or first_iter + self.token_to_kv_pool_allocator.available_size() + < get_required_tokens(len(sorted_indices)) + or first_iter ): if len(sorted_indices) == 1: # Corner case: only one request left assert ( - self.token_to_kv_pool_allocator.available_size() > 0 + self.token_to_kv_pool_allocator.available_size() > 0 ), "No space left for only one request" break @@ -1271,18 +1272,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1291,8 +1292,8 @@ def get_required_tokens(num_reqs: int): # NOTE(lsyin): we should use the newly evictable memory instantly. 
residual_size = ( - len(sorted_indices) * global_config.retract_decode_steps - - self.token_to_kv_pool_allocator.available_size() + len(sorted_indices) * global_config.retract_decode_steps + - self.token_to_kv_pool_allocator.available_size() ) residual_size = max(0, residual_size) self.tree_cache.evict(residual_size) @@ -1306,8 +1307,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1395,16 +1396,16 @@ def prepare_for_decode(self): ) def filter_batch( - self, - chunked_req_to_exclude: Optional[Req] = None, - keep_indices: Optional[List[int]] = None, + self, + chunked_req_to_exclude: Optional[Req] = None, + keep_indices: Optional[List[int]] = None, ): if keep_indices is None: keep_indices = [ i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: @@ -1493,12 +1494,12 @@ def get_model_worker_batch(self) -> ModelWorkerBatch: # Create seq_lens_cpu when needed if ( - ( - global_server_args_dict["use_mla_backend"] - and global_server_args_dict["attention_backend"] == "flashinfer" - ) - or global_server_args_dict["enable_flashmla"] - or global_server_args_dict["attention_backend"] == "fa3" + ( + global_server_args_dict["use_mla_backend"] + and global_server_args_dict["attention_backend"] == "flashinfer" + ) + or global_server_args_dict["enable_flashmla"] + or global_server_args_dict["attention_backend"] == "fa3" ): seq_lens_cpu = self.seq_lens.cpu() else: @@ -1638,13 +1639,13 @@ class ModelWorkerBatch: @triton.jit def write_req_to_token_pool_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices, - pre_lens, - seq_lens, - extend_lens, - out_cache_loc, - req_to_token_ptr_stride: tl.constexpr, + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 512 pid = tl.program_id(0) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a88c13094fa..6153ba4560a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -173,13 +172,14 @@ def __init__( "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, "n_share_experts_fusion": server_args.n_share_experts_fusion, "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion, + "ep_num_redundant_experts": server_args.ep_num_redundant_experts, "use_mla_backend": self.use_mla_backend, } ) set_global_expert_location_metadata(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading 
min_per_gpu_memory = self.init_torch_distributed() @@ -926,7 +926,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From cf760670f35de50315c10897cb603a6f721cd000 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:58:43 +0800 Subject: [PATCH 0320/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 13 ++++++------- python/sglang/srt/layers/moe/topk.py | 3 +-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e61e2e02cdd..cfd2f3c671a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -2,8 +2,7 @@ from typing import Callable, List, Optional, Tuple import torch - -from sglang.srt.managers.schedule_batch import global_expert_location_metadata +from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata try: from deep_gemm import ( @@ -260,7 +259,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -413,7 +412,7 @@ def weight_loader( expert_id: int, ) -> None: physical_expert_ids = ( - global_expert_location_metadata.logical_to_global_physical(expert_id) + get_global_expert_location_metadata().logical_to_global_physical(expert_id) ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( @@ -457,7 +456,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -489,11 +488,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index f3d42512ec4..f16ac86c7fa 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,7 +16,6 @@ import torch import torch.nn.functional as F - from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import ( global_expert_location_metadata, @@ -301,7 +300,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = global_expert_location_metadata.chosen_logical_to_physical_map[topk_ids] + topk_ids = get_global_expert_location_metadata().chosen_logical_to_physical_map[topk_ids] 
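
A minimal sketch, not part of the patch, of the logical-to-physical gather the
line above performs, using the identity layout that
ExpertLocationMetadata.init_new builds later in this series (no experts are
duplicated yet); shapes follow the dataclass comments there:

    import torch

    num_layers, num_logical_experts = 2, 4
    num_physical_experts = num_logical_experts  # init_new adds no redundancy yet

    # (layers, num_physical_experts): physical slot i holds logical expert i
    physical_to_logical_map = torch.arange(num_physical_experts).repeat(num_layers, 1)

    # (layers, num_logical_experts, X): a single replica per logical expert
    logical_to_all_physical_map = physical_to_logical_map[..., None]

    # The remap applied to the routed ids in select_experts:
    topk_ids = torch.tensor([[0, 3], [1, 2]])
    physical_topk = logical_to_all_physical_map[0, topk_ids, 0]
    assert torch.equal(physical_topk, topk_ids)  # identity under this layout

Once redundant experts are actually placed, the same gather can send different
ranks to different physical replicas of one logical expert, which is what the
per-rank dispatch map introduced a few commits below is for.
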
expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From ee32245c0ede514ec69c8e74f4033a8154eabca0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:58:51 +0800 Subject: [PATCH 0321/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index f16ac86c7fa..296e2da1326 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -19,7 +19,7 @@ from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import ( global_expert_location_metadata, - global_server_args_dict, + global_server_args_dict, get_global_expert_location_metadata, ) from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip From 7ad2a48ab66f58bfd1d6b1d3b0dccc1f0b864588 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 16:59:23 +0800 Subject: [PATCH 0322/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 296e2da1326..468d2279045 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,7 +300,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = get_global_expert_location_metadata().chosen_logical_to_physical_map[topk_ids] + topk_ids = get_global_expert_location_metadata().logical_to_rank_chosen_physical_map[rank, layer_id, topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From d845b4e826ebbd16751a99e9d0655b4fe3f68b16 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:00:00 +0800 Subject: [PATCH 0323/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index cfd2f3c671a..62e29104c7d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -412,7 +412,7 @@ def weight_loader( expert_id: int, ) -> None: physical_expert_ids = ( - get_global_expert_location_metadata().logical_to_global_physical(expert_id) + get_global_expert_location_metadata().logical_to_all_physical(layer_id, expert_id) ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( From b7a52d5b18e590586c441122b0153c45bee91b7e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:02:16 +0800 Subject: [PATCH 0324/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 6 +++++- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 62e29104c7d..eec4c84d87e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -132,6 +132,7 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, + layer_id: int, params_dtype: Optional[torch.dtype] = None, renormalize: bool = True, use_grouped_topk: bool = False, @@ -154,6 +155,7 @@ def __init__( ) self.tp_rank = get_tensor_model_parallel_rank() + self.layer_id = layer_id self.num_experts = num_experts assert self.num_experts % self.tp_size == 0 self.num_experts_per_partition = 
self.num_experts // self.tp_size @@ -412,7 +414,7 @@ def weight_loader( expert_id: int, ) -> None: physical_expert_ids = ( - get_global_expert_location_metadata().logical_to_all_physical(layer_id, expert_id) + get_global_expert_location_metadata().logical_to_all_physical(self.layer_id, expert_id) ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( @@ -823,6 +825,7 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, + layer_id: int, params_dtype: Optional[torch.dtype] = None, renormalize: bool = True, use_grouped_topk: bool = False, @@ -841,6 +844,7 @@ def __init__( top_k, hidden_size, intermediate_size, + layer_id, params_dtype, renormalize, use_grouped_topk, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index a33cf691fa5..4b23ad8d95d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -268,6 +268,7 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, + layer_id: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, renormalize: bool = True, From da995b194125fb6330f66fcbd59e2a0b95befc4e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:03:04 +0800 Subject: [PATCH 0325/1089] more --- python/sglang/srt/models/deepseek_v2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2380031fc73..2722ce871f3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -167,6 +167,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + layer_id: int = -1, prefix: str = "", ): super().__init__() @@ -178,6 +179,7 @@ def __init__( if global_server_args_dict["n_share_experts_fusion"] is not None else 0 ) + self.layer_id = layer_id self.routed_scaling_factor = config.routed_scaling_factor if self.tp_size > config.n_routed_experts: @@ -205,6 +207,7 @@ def __init__( top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, + layer_id=self.layer_id, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -1084,6 +1087,7 @@ def is_sparse_layer(l: int): config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, ) self.is_sparse = True else: From abbb2af357bad06d65217f5c6dc7b6c604f2d906 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:03:31 +0800 Subject: [PATCH 0326/1089] more --- python/sglang/srt/models/deepseek_v2.py | 61 ++++++++++++------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2722ce871f3..f9233ac189e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,10 +22,6 @@ import torch import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -74,6 +70,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, 
add_prefix, is_cuda, is_hip +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -414,7 +413,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -527,12 +526,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -544,8 +543,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -586,7 +585,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -773,16 +772,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -835,11 +834,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -915,15 +914,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., 
self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -980,7 +979,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1499,7 +1498,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1530,11 +1529,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From eedfcb10edb973364bb143d230271031da16e23c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:04:02 +0800 Subject: [PATCH 0327/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/srt/managers/expert_location.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 468d2279045..8a54a857e84 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,7 +300,7 @@ def select_experts( ) # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = get_global_expert_location_metadata().logical_to_rank_chosen_physical_map[rank, layer_id, topk_ids] + topk_ids = get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[rank, layer_id, topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index a6e565c4107..f2749dc1b0e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -15,7 +15,7 @@ class ExpertLocationMetadata: num_logical_experts: int physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) - logical_to_rank_chosen_physical_map: ( + logical_to_rank_dispatch_physical_map: ( torch.Tensor ) # (num_gpus, layers, num_logical_experts) @@ -46,7 +46,7 @@ def init_new(num_layers: int, num_logical_experts: int): logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 )[..., None], - logical_to_rank_chosen_physical_map=torch.arange( + logical_to_rank_dispatch_physical_map=torch.arange( 0, num_physical_experts ).repeat(num_layers, 1)[..., None], ) From 
3852d05be0c9f8c6970fa76e49b91c0d03f64f3f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:04:14 +0800 Subject: [PATCH 0328/1089] more --- python/sglang/srt/managers/expert_location.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index f2749dc1b0e..5a898752077 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -15,9 +15,8 @@ class ExpertLocationMetadata: num_logical_experts: int physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) - logical_to_rank_dispatch_physical_map: ( - torch.Tensor - ) # (num_gpus, layers, num_logical_experts) + # (num_gpus, layers, num_logical_experts) + logical_to_rank_dispatch_physical_map: torch.Tensor # -------------------------------- construction and mutation ------------------------------------ @@ -32,7 +31,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -64,7 +63,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id From cd4caee38d430a08b7e7cf301b903b9affe7e0e3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:06:38 +0800 Subject: [PATCH 0329/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 1 + python/sglang/srt/layers/moe/topk.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index eec4c84d87e..ac19f03494f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -224,6 +224,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, + layer_id=self.layer_id, ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 8a54a857e84..bc7e835feb0 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -248,6 +248,7 @@ def select_experts( custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, + expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, ): n_share_experts_fusion = 0 if global_server_args_dict["n_share_experts_fusion"] is not None: @@ -299,8 +300,9 @@ def select_experts( renormalize=renormalize, ) - # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[rank, layer_id, topk_ids] + if expert_logical_to_rank_dispatch_physical_map is not None: + # TODO this is inefficient, and I will fuse into 
existing kernels + topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From 031ce90133d30c617b9b0890789903752c0afdd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:06:51 +0800 Subject: [PATCH 0330/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ac19f03494f..e1b91d198b4 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -224,7 +224,8 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, - layer_id=self.layer_id, + expert_logical_to_rank_dispatch_physical_map= + get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[rank, self.layer_id, :], ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( From 4623be41d09624aa161500a3e443a17816674517 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:07:06 +0800 Subject: [PATCH 0331/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e1b91d198b4..06c568237ed 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -224,8 +224,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, - expert_logical_to_rank_dispatch_physical_map= - get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[rank, self.layer_id, :], + expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[self.tp_rank, self.layer_id, :], ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( From c19f4d833533f3e1980c607940854fa5fe0df881 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:07:38 +0800 Subject: [PATCH 0332/1089] fmt --- python/sglang/srt/layers/moe/ep_moe/layer.py | 17 ++++-- python/sglang/srt/layers/moe/topk.py | 4 +- python/sglang/srt/managers/expert_location.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 61 ++++++++++--------- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 06c568237ed..8d3d4687ecc 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple import torch + from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata try: @@ -224,7 +225,9 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, - expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[self.tp_rank, self.layer_id, :], + 
expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ + self.tp_rank, self.layer_id, : + ], ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -262,7 +265,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -415,7 +418,9 @@ def weight_loader( expert_id: int, ) -> None: physical_expert_ids = ( - get_global_expert_location_metadata().logical_to_all_physical(self.layer_id, expert_id) + get_global_expert_location_metadata().logical_to_all_physical( + self.layer_id, expert_id + ) ) for physical_expert_id in physical_expert_ids: self._weight_loader_physical( @@ -459,7 +464,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -491,11 +496,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index bc7e835feb0..b41fc5430bc 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,10 +16,12 @@ import torch import torch.nn.functional as F + from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import ( + get_global_expert_location_metadata, global_expert_location_metadata, - global_server_args_dict, get_global_expert_location_metadata, + global_server_args_dict, ) from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 5a898752077..35088ad165d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -31,7 +31,7 @@ def from_model_config(model_config: ModelConfig): def init_new(num_layers: int, num_logical_experts: int): # TODO handle more complex cases like duplicating experts on different GPUs num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() + num_logical_experts // get_tensor_model_parallel_world_size() ) num_physical_experts = num_logical_experts @@ -63,7 +63,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 
f9233ac189e..2722ce871f3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,6 +22,10 @@ import torch import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -70,9 +74,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -413,7 +414,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -526,12 +527,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -543,8 +544,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -585,7 +586,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -772,16 +773,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -834,11 +835,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., 
self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -914,15 +915,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -979,7 +980,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1498,7 +1499,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1529,11 +1530,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From b70afc146da9bfa858b76adf1e82acc06b1289c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:08:53 +0800 Subject: [PATCH 0333/1089] fmt --- python/sglang/srt/managers/schedule_batch.py | 200 +++++++++--------- .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/srt/server_args.py | 74 +++---- 4 files changed, 144 insertions(+), 139 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index df4be21fb32..39840545a1d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -271,17 +271,17 @@ def hash_feature(f): def is_audio(self): return ( - self.modality == Modality.AUDIO + self.modality == Modality.AUDIO ) and not MultimodalDataItem.is_empty_list(self.audio_features) def is_image(self): return ( - self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES + self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES ) and not MultimodalDataItem.is_empty_list(self.pixel_values) def is_video(self): return ( - self.modality == Modality.VIDEO + self.modality == Modality.VIDEO ) and not MultimodalDataItem.is_empty_list(self.pixel_values) def validate(self): @@ -390,24 +390,24 @@ 
class Req: """The input and output status of a request.""" def __init__( - self, - rid: str, - origin_input_text: str, - origin_input_ids: Tuple[int], - sampling_params: SamplingParams, - return_logprob: bool = False, - top_logprobs_num: int = 0, - token_ids_logprob: List[int] = None, - stream: bool = False, - origin_input_ids_unpadded: Optional[Tuple[int]] = None, - lora_path: Optional[str] = None, - input_embeds: Optional[List[List[float]]] = None, - session_id: Optional[str] = None, - custom_logit_processor: Optional[str] = None, - return_hidden_states: bool = False, - eos_token_ids: Optional[Set[int]] = None, - bootstrap_host: Optional[str] = None, - bootstrap_room: Optional[int] = None, + self, + rid: str, + origin_input_text: str, + origin_input_ids: Tuple[int], + sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + token_ids_logprob: List[int] = None, + stream: bool = False, + origin_input_ids_unpadded: Optional[Tuple[int]] = None, + lora_path: Optional[str] = None, + input_embeds: Optional[List[List[float]]] = None, + session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, + eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_room: Optional[int] = None, ): # Input and output info self.rid = rid @@ -573,9 +573,9 @@ def finished(self) -> bool: return self.finished_reason is not None def init_next_round_input( - self, - tree_cache: Optional[BasePrefixCache] = None, - enable_hierarchical_cache=False, + self, + tree_cache: Optional[BasePrefixCache] = None, + enable_hierarchical_cache=False, ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: @@ -621,7 +621,7 @@ def init_incremental_detokenize(self): ) all_ids = self.origin_input_ids_unpadded + self.output_ids - return all_ids[self.surr_offset:], self.read_offset - self.surr_offset + return all_ids[self.surr_offset :], self.read_offset - self.surr_offset def check_finished(self): if self.finished(): @@ -653,7 +653,7 @@ def check_finished(self): matched_eos |= last_token_id == self.tokenizer.eos_token_id if self.tokenizer.additional_stop_token_ids: matched_eos |= ( - last_token_id in self.tokenizer.additional_stop_token_ids + last_token_id in self.tokenizer.additional_stop_token_ids ) if matched_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) @@ -662,7 +662,7 @@ def check_finished(self): # Check stop strings if len(self.sampling_params.stop_strs) > 0: tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1):] + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] ) for stop_str in self.sampling_params.stop_strs: @@ -778,15 +778,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): @classmethod def init_new( - cls, - reqs: List[Req], - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, - tree_cache: BasePrefixCache, - model_config: ModelConfig, - enable_overlap: bool, - spec_algorithm: SpeculativeAlgorithm, - enable_custom_logit_processor: bool, + cls, + reqs: List[Req], + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, + tree_cache: BasePrefixCache, + model_config: ModelConfig, + enable_overlap: bool, + spec_algorithm: SpeculativeAlgorithm, + enable_custom_logit_processor: bool, ): return_logprob = any(req.return_logprob for req in reqs) @@ -850,17 +850,17 @@ def alloc_token_slots(self, 
num_tokens: int, backup_state: bool = False): return out_cache_loc def alloc_paged_token_slots_extend( - self, - prefix_lens: torch.Tensor, - seq_lens: torch.Tensor, - last_loc: torch.Tensor, - extend_num_tokens: int, - backup_state: bool = False, + self, + prefix_lens: torch.Tensor, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + backup_state: bool = False, ): if ( - self.token_to_kv_pool_allocator.available_size() - < extend_num_tokens - + len(seq_lens) * self.token_to_kv_pool_allocator.page_size + self.token_to_kv_pool_allocator.available_size() + < extend_num_tokens + + len(seq_lens) * self.token_to_kv_pool_allocator.page_size ): if self.tree_cache is not None: self.tree_cache.evict( @@ -891,15 +891,15 @@ def alloc_paged_token_slots_extend( return out_cache_loc def alloc_paged_token_slots_decode( - self, - seq_lens: torch.Tensor, - last_loc: torch.Tensor, - backup_state: bool = False, + self, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + backup_state: bool = False, ): if self.tree_cache is not None: if ( - self.token_to_kv_pool_allocator.available_size() - < len(seq_lens) * self.token_to_kv_pool_allocator.page_size + self.token_to_kv_pool_allocator.available_size() + < len(seq_lens) * self.token_to_kv_pool_allocator.page_size ): self.tree_cache.evict( len(seq_lens) * self.token_to_kv_pool_allocator.page_size, @@ -958,15 +958,15 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) # NOTE: the encoder part should be considered as a whole assert len(req.prefix_indices) == 0 input_ids[i] = input_ids[i][encoder_len:] - encoder_out_cache_loc.append(self.out_cache_loc[pt: pt + encoder_len]) + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) decoder_out_cache_loc.append( - self.out_cache_loc[pt + encoder_len: pt + req.extend_input_len] + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] ) self.extend_lens[i] -= encoder_len self.extend_num_tokens -= encoder_len else: decoder_out_cache_loc.append( - self.out_cache_loc[pt: pt + req.extend_input_len] + self.out_cache_loc[pt : pt + req.extend_input_len] ) self.prefix_lens[i] -= encoder_len @@ -1005,7 +1005,7 @@ def prepare_for_extend(self): # Init tensors reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs] + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] @@ -1080,8 +1080,8 @@ def prepare_for_extend(self): global_start_idx = req.logprob_start_len logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1: global_end_idx + 1 - ] + global_start_idx + 1 : global_end_idx + 1 + ] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of @@ -1089,9 +1089,9 @@ def prepare_for_extend(self): extend_input_logprob_token_ids.extend( [0] * ( - req.extend_input_len - - req.extend_logprob_start_len - - len(logprob_token_ids) + req.extend_input_len + - req.extend_logprob_start_len + - len(logprob_token_ids) ) ) @@ -1155,7 +1155,7 @@ def prepare_for_extend(self): for i in range(bs): self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt: pt + extend_lens[i]], + out_cache_loc[pt : pt + extend_lens[i]], ) pt += extend_lens[i] @@ -1206,9 +1206,9 @@ def new_page_count_next_decode(self): def check_decode_mem(self, buf_multiplier=1): 
tokens_required = ( - self.new_page_count_next_decode() - * buf_multiplier - * self.token_to_kv_pool_allocator.page_size + self.new_page_count_next_decode() + * buf_multiplier + * self.token_to_kv_pool_allocator.page_size ) if self.token_to_kv_pool_allocator.available_size() >= tokens_required: @@ -1240,27 +1240,27 @@ def get_required_tokens(num_reqs: int): headroom_for_spec_decode = 0 if server_args.speculative_algorithm: headroom_for_spec_decode += ( - num_reqs - * server_args.speculative_eagle_topk - * server_args.speculative_num_steps - + num_reqs * server_args.speculative_num_draft_tokens + num_reqs + * server_args.speculative_eagle_topk + * server_args.speculative_num_steps + + num_reqs * server_args.speculative_num_draft_tokens ) return ( - num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode + num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode ) retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while ( - self.token_to_kv_pool_allocator.available_size() - < get_required_tokens(len(sorted_indices)) - or first_iter + self.token_to_kv_pool_allocator.available_size() + < get_required_tokens(len(sorted_indices)) + or first_iter ): if len(sorted_indices) == 1: # Corner case: only one request left assert ( - self.token_to_kv_pool_allocator.available_size() > 0 + self.token_to_kv_pool_allocator.available_size() > 0 ), "No space left for only one request" break @@ -1272,18 +1272,18 @@ def get_required_tokens(num_reqs: int): if isinstance(self.tree_cache, ChunkCache): # ChunkCache does not have eviction token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : seq_lens_cpu[idx] - ] + req.req_pool_idx, : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) else: # TODO: apply more fine-grained retraction last_uncached_pos = ( - len(req.prefix_indices) // server_args.page_size - ) * server_args.page_size + len(req.prefix_indices) // server_args.page_size + ) * server_args.page_size token_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, last_uncached_pos: seq_lens_cpu[idx] - ] + req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx] + ] self.token_to_kv_pool_allocator.free(token_indices) self.req_to_token_pool.free(req.req_pool_idx) @@ -1292,8 +1292,8 @@ def get_required_tokens(num_reqs: int): # NOTE(lsyin): we should use the newly evictable memory instantly. 
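            # residual_size below is the shortfall between the headroom the
            # remaining requests need (retract_decode_steps tokens each) and the
            # memory already free; evict exactly that much from the tree cache.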
residual_size = ( - len(sorted_indices) * global_config.retract_decode_steps - - self.token_to_kv_pool_allocator.available_size() + len(sorted_indices) * global_config.retract_decode_steps + - self.token_to_kv_pool_allocator.available_size() ) residual_size = max(0, residual_size) self.tree_cache.evict(residual_size) @@ -1307,8 +1307,8 @@ def get_required_tokens(num_reqs: int): total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs) new_estimate_ratio = ( - total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) - ) / total_max_new_tokens + total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs) + ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio @@ -1396,16 +1396,16 @@ def prepare_for_decode(self): ) def filter_batch( - self, - chunked_req_to_exclude: Optional[Req] = None, - keep_indices: Optional[List[int]] = None, + self, + chunked_req_to_exclude: Optional[Req] = None, + keep_indices: Optional[List[int]] = None, ): if keep_indices is None: keep_indices = [ i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: @@ -1494,12 +1494,12 @@ def get_model_worker_batch(self) -> ModelWorkerBatch: # Create seq_lens_cpu when needed if ( - ( - global_server_args_dict["use_mla_backend"] - and global_server_args_dict["attention_backend"] == "flashinfer" - ) - or global_server_args_dict["enable_flashmla"] - or global_server_args_dict["attention_backend"] == "fa3" + ( + global_server_args_dict["use_mla_backend"] + and global_server_args_dict["attention_backend"] == "flashinfer" + ) + or global_server_args_dict["enable_flashmla"] + or global_server_args_dict["attention_backend"] == "fa3" ): seq_lens_cpu = self.seq_lens.cpu() else: @@ -1639,13 +1639,13 @@ class ModelWorkerBatch: @triton.jit def write_req_to_token_pool_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices, - pre_lens, - seq_lens, - extend_lens, - out_cache_loc, - req_to_token_ptr_stride: tl.constexpr, + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 512 pid = tl.program_id(0) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6153ba4560a..c4d8933415a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -179,7 +180,7 @@ def __init__( set_global_expert_location_metadata(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -926,7 +927,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e3b1b3a3b32..d0bb51f05ff 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -203,7 +203,9 @@ def __init__( ) self.experts = MoEImpl( - num_experts=config.n_routed_experts + self.n_share_experts_fusion + global_server_args_dict["ep_num_redundant_experts"], + num_experts=config.n_routed_experts + + self.n_share_experts_fusion + + global_server_args_dict["ep_num_redundant_experts"], top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6c9a7813746..b7de3cd026b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -323,7 +323,9 @@ def __post_init__(self): ) if self.ep_num_redundant_experts > 0: - assert self.enable_deepep_moe, "ep_num_redundant_experts currently requires DeepEP MoE" + assert ( + self.enable_deepep_moe + ), "ep_num_redundant_experts currently requires DeepEP MoE" # Speculative Decoding if self.speculative_algorithm == "NEXTN": @@ -418,8 +420,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -443,21 +445,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." 
+ '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -470,13 +472,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -511,9 +513,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -555,8 +557,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -576,7 +578,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1041,7 +1043,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1054,8 +1056,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. 
" + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1120,7 +1122,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1134,7 +1136,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From cc1cbd4a56a3aadc5d929f2a49999f8af43e5649 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:09:43 +0800 Subject: [PATCH 0334/1089] more --- python/sglang/srt/managers/expert_location.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 35088ad165d..1dd726ba096 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,7 +2,6 @@ from typing import List import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.model_loader import get_model_architecture @@ -29,11 +28,10 @@ def from_model_config(model_config: ModelConfig): @staticmethod def init_new(num_layers: int, num_logical_experts: int): - # TODO handle more complex cases like duplicating experts on different GPUs - num_local_physical_experts = ( - num_logical_experts // get_tensor_model_parallel_world_size() - ) - num_physical_experts = num_logical_experts + num_physical_experts = TODO + world_size = get_tensor_model_parallel_world_size() + assert num_physical_experts % world_size == 0 + num_local_physical_experts = num_physical_experts // world_size return ExpertLocationMetadata( num_layers=num_layers, @@ -63,7 +61,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id From 64c3b19eebf0b093784a6c6c886bc9316a66edd2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:10:22 +0800 Subject: [PATCH 0335/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 1dd726ba096..9f5938f58f6 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -4,6 +4,7 @@ import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture @@ -28,7 +29,7 @@ def from_model_config(model_config: ModelConfig): @staticmethod def 
init_new(num_layers: int, num_logical_experts: int): - num_physical_experts = TODO + num_physical_experts = num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] world_size = get_tensor_model_parallel_world_size() assert num_physical_experts % world_size == 0 num_local_physical_experts = num_physical_experts // world_size From e09f1cf9d51f30cdf6379994d7a342e96c003932 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:10:51 +0800 Subject: [PATCH 0336/1089] fmt --- python/sglang/srt/managers/expert_location.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9f5938f58f6..160e04b3ebf 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,6 +2,7 @@ from typing import List import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -29,7 +30,9 @@ def from_model_config(model_config: ModelConfig): @staticmethod def init_new(num_layers: int, num_logical_experts: int): - num_physical_experts = num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + num_physical_experts = ( + num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + ) world_size = get_tensor_model_parallel_world_size() assert num_physical_experts % world_size == 0 num_local_physical_experts = num_physical_experts // world_size @@ -62,7 +65,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id From ba13a2d750e694389c2ac101c99572d301251976 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:11:01 +0800 Subject: [PATCH 0337/1089] more --- python/sglang/srt/managers/expert_location.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 160e04b3ebf..e5fa2c145cf 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,7 +2,6 @@ from typing import List import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -31,7 +30,7 @@ def from_model_config(model_config: ModelConfig): @staticmethod def init_new(num_layers: int, num_logical_experts: int): num_physical_experts = ( - num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) world_size = get_tensor_model_parallel_world_size() assert num_physical_experts % world_size == 0 @@ -43,7 +42,7 @@ def init_new(num_layers: int, num_logical_experts: int): num_local_physical_experts=num_local_physical_experts, physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 - ), + ) % num_logical_experts, logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 )[..., None], @@ -65,7 +64,7 @@ def physical_to_local_physical(self, 
global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id From 78ab92c2ac8b5c15c65f9a812f02c2f89a36c8b6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:11:45 +0800 Subject: [PATCH 0338/1089] more --- python/sglang/srt/managers/expert_location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e5fa2c145cf..aaae65faa79 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -43,6 +43,8 @@ def init_new(num_layers: int, num_logical_experts: int): physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 ) % num_logical_experts, + # Throw away the redundant experts here - highly inefficient, but we do not care since we will + # use EPLB distribution logic logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 )[..., None], From fe100869d3e8e41db475501d144e2085e1ffb0d3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:12:06 +0800 Subject: [PATCH 0339/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index aaae65faa79..c868f40eef0 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -42,14 +42,15 @@ def init_new(num_layers: int, num_logical_experts: int): num_local_physical_experts=num_local_physical_experts, physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 - ) % num_logical_experts, + ) + % num_logical_experts, # Throw away the redundant experts here - highly inefficient, but we do not care since we will # use EPLB distribution logic - logical_to_all_physical_map=torch.arange(0, num_physical_experts).repeat( + logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( num_layers, 1 )[..., None], logical_to_rank_dispatch_physical_map=torch.arange( - 0, num_physical_experts + 0, num_logical_experts ).repeat(num_layers, 1)[..., None], ) From 5215648694d07182b1626b9ae523d9943f5ee534 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:13:11 +0800 Subject: [PATCH 0340/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c868f40eef0..825e7314c72 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,6 +2,7 @@ from typing import List import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -30,7 +31,7 @@ def from_model_config(model_config: ModelConfig): @staticmethod def init_new(num_layers: int, num_logical_experts: int): num_physical_experts = ( - num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) world_size = get_tensor_model_parallel_world_size() assert 
num_physical_experts % world_size == 0 @@ -43,7 +44,7 @@ def init_new(num_layers: int, num_logical_experts: int): physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 ) - % num_logical_experts, + % num_logical_experts, # Throw away the redundant experts here - highly inefficient, but we do not care since we will # use EPLB distribution logic logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( @@ -67,7 +68,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id From 8e2c6ff7222e55bbeaf0cc6ce78b6eb4411606ae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:16:00 +0800 Subject: [PATCH 0341/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 python/sglang/srt/managers/eplb_manager.py diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py new file mode 100644 index 00000000000..f3973ec4d14 --- /dev/null +++ b/python/sglang/srt/managers/eplb_manager.py @@ -0,0 +1,2 @@ +class EPLBManager: + pass From adee880db4ca143826eb1f8482dd517ce4798185 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:16:50 +0800 Subject: [PATCH 0342/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 93a9507bf19..c72a2ea86d5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -52,6 +52,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers import expert_distribution +from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, @@ -201,6 +202,8 @@ def __init__( revision=server_args.revision, ) + self.eplb_manager = EPLBManager() + # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} From ff029cc5c014115161fdb1807ae023b5609bf151 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:18:27 +0800 Subject: [PATCH 0343/1089] more --- .../sglang/srt/managers/tokenizer_manager.py | 5 +- python/sglang/srt/misc/deepseek_eplb.py | 169 ++++++++++++++++++ 2 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/misc/deepseek_eplb.py diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c72a2ea86d5..3e626a15ff3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -965,8 +964,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + 
state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/misc/deepseek_eplb.py b/python/sglang/srt/misc/deepseek_eplb.py new file mode 100644 index 00000000000..99b3e2dfc0b --- /dev/null +++ b/python/sglang/srt/misc/deepseek_eplb.py @@ -0,0 +1,169 @@ +# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package + +from typing import Tuple + +import torch + + +def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs + are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. 
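+    For example, weight [[4., 2., 1.]] with num_phy = 4 duplicates the heaviest
+    logical expert 0 once, yielding phy2log [[0, 1, 2, 0]] and logcnt [[2, 1, 1]].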
+ + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int, + num_groups: int, num_nodes: int, num_gpus: int): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, + device=group_pack_index.device).view(1, -1, 1)).flatten(-2) + 
pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int, + num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, + num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, dtype=torch.int64, device=logcnt.device) + log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1)) + return phy2log, log2phy, logcnt + + +__all__ = ['rebalance_experts'] From 44246e23f18710667507322a3c104aaefe465dc7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:19:08 +0800 Subject: [PATCH 0344/1089] more --- python/sglang/srt/managers/eplb_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index f3973ec4d14..cc7a0416831 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,2 +1,3 @@ class EPLBManager: - pass + def rebalance_experts(self): + TODO From 63220f2e34c3bf1afb413990a584fd75d9aa8d6e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:20:20 +0800 Subject: [PATCH 0345/1089] more --- python/sglang/srt/entrypoints/engine.py | 3 +++ python/sglang/srt/managers/eplb_manager.py | 2 +- python/sglang/srt/managers/tokenizer_manager.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b484e9564ae..322c1280525 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -301,6 +301,9 @@ def get_server_info(self): "version": __version__, } + def rebalance_experts(self): + self.tokenizer_manager.rebalance_experts() + def init_weights_update_group( self, master_address: str, diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index cc7a0416831..267d1142719 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,3 +1,3 @@ class 
EPLBManager: - def rebalance_experts(self): + async def rebalance_experts(self): TODO diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3e626a15ff3..47c95685dea 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -707,6 +707,10 @@ async def _wait_for_model_update_from_disk( all_paused_requests = [r.num_paused_requests for r in result] return all_success, all_message, all_paused_requests + async def rebalance_experts(self): + self.auto_create_handle_loop() + await self.eplb_manager.rebalance_experts() + async def init_weights_update_group( self, obj: InitWeightsUpdateGroupReqInput, From a9aa0420eb94b8c46a6465cff1d0221ad26f142f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:21:11 +0800 Subject: [PATCH 0346/1089] more --- python/sglang/srt/entrypoints/http_server.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index f43afec5423..19ca10ccef1 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -370,6 +370,12 @@ async def dump_expert_distribution_record_async(): return ORJSONResponse(content, status_code=200) +@app.post("/rebalance_experts") +async def rebalance_experts(): + await _global_state.tokenizer_manager.rebalance_experts() + return ORJSONResponse({}, status_code=200) + + @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk inplace without re-launching the server.""" @@ -634,10 +640,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, From b437345ed51ca356c5907b600e31e51eb3e6ac5a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:21:38 +0800 Subject: [PATCH 0347/1089] fmt --- python/sglang/srt/entrypoints/http_server.py | 10 +- .../sglang/srt/managers/tokenizer_manager.py | 5 +- python/sglang/srt/misc/deepseek_eplb.py | 107 +++++++++++++----- 3 files changed, 87 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 19ca10ccef1..71c30d38370 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -640,10 +640,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 47c95685dea..c669aef2daf 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -968,8 +969,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/misc/deepseek_eplb.py b/python/sglang/srt/misc/deepseek_eplb.py index 99b3e2dfc0b..e5875b21342 100644 --- a/python/sglang/srt/misc/deepseek_eplb.py +++ b/python/sglang/srt/misc/deepseek_eplb.py @@ -5,7 +5,9 @@ import torch -def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: +def balanced_packing( + weight: torch.Tensor, num_packs: int +) -> Tuple[torch.Tensor, torch.Tensor]: """ Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible. 
@@ -23,19 +25,23 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor groups_per_pack = num_groups // num_packs if groups_per_pack == 1: - pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=weight.device + ).expand(weight.shape) rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) return pack_index, rank_in_pack indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") rank_in_pack = torch.full_like(pack_index, fill_value=-1) for i in range(num_layers): pack_weights = [0] * num_packs pack_items = [0] * num_packs for group in indices[i]: - pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), - key=pack_weights.__getitem__) + pack = min( + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) assert pack_items[pack] < groups_per_pack pack_index[i, group] = pack rank_in_pack[i, group] = pack_items[pack] @@ -44,7 +50,9 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor return pack_index, rank_in_pack -def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def replicate_experts( + weight: torch.Tensor, num_phy: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. @@ -73,8 +81,13 @@ def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, return phy2log, rank, logcnt -def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int, - num_groups: int, num_nodes: int, num_gpus: int): +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): """ Parameters: weight: [num_moe_layers, num_logical_experts] @@ -99,20 +112,34 @@ def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: i def inverse(perm: torch.Tensor) -> torch.Tensor: inv = torch.empty_like(perm) - inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( + perm.shape + ), + ) return inv # Step 1: pack groups to nodes tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) - log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + - torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) + ).flatten(-2) mlog2log = inverse(log2mlog) # Step 2: construct redundant experts within nodes # [num_layers * num_nodes, num_logical_experts // num_nodes] - tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) - phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_mlog = 
weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes + ) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes + ) # Step 3: pack physical_experts to GPUs # [num_layers * num_nodes, num_physical_experts // num_nodes] @@ -121,18 +148,31 @@ def inverse(perm: torch.Tensor) -> torch.Tensor: phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack pphy2phy = inverse(phy2pphy) - pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + - torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, - device=group_pack_index.device).view(1, -1, 1)).flatten(-2) + pphy2mlog = phy2mlog.gather( + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) return pphy2log, pphyrank, logcnt -def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int, - num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Entry point for expert-parallelism load balancer. @@ -152,18 +192,29 @@ def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int, weight = weight.float().cpu() if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, - num_groups, num_nodes, num_gpus) + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) else: # use global load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus) + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) maxlogcnt = logcnt.max().item() - log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), - -1, dtype=torch.int64, device=logcnt.device) - log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( - num_layers, -1)) + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), + ) return phy2log, log2phy, logcnt -__all__ = ['rebalance_experts'] +__all__ = ["rebalance_experts"] From 9ea883b149c4a370aa0cf89c5cfbe659e52e3bbf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:25:47 +0800 Subject: [PATCH 0348/1089] mv --- python/sglang/srt/{misc => managers}/deepseek_eplb.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/sglang/srt/{misc => managers}/deepseek_eplb.py (100%) diff --git a/python/sglang/srt/misc/deepseek_eplb.py b/python/sglang/srt/managers/deepseek_eplb.py similarity index 100% rename from 
python/sglang/srt/misc/deepseek_eplb.py rename to python/sglang/srt/managers/deepseek_eplb.py From 4eb1936f4daf825ca6b90b2e5eac52352e87e364 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:26:07 +0800 Subject: [PATCH 0349/1089] more --- python/sglang/srt/managers/expert_distribution_recorder.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 python/sglang/srt/managers/expert_distribution_recorder.py diff --git a/python/sglang/srt/managers/expert_distribution_recorder.py b/python/sglang/srt/managers/expert_distribution_recorder.py new file mode 100644 index 00000000000..ab63f3f0508 --- /dev/null +++ b/python/sglang/srt/managers/expert_distribution_recorder.py @@ -0,0 +1,2 @@ +class ExpertDistributionRecorder: + TODO From 728bae503605d9657f229fcfa55c476c642b5985 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:26:34 +0800 Subject: [PATCH 0350/1089] more --- .../sglang/srt/managers/expert_distribution_recorder.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution_recorder.py b/python/sglang/srt/managers/expert_distribution_recorder.py index ab63f3f0508..6cb39dadee6 100644 --- a/python/sglang/srt/managers/expert_distribution_recorder.py +++ b/python/sglang/srt/managers/expert_distribution_recorder.py @@ -1,2 +1,9 @@ class ExpertDistributionRecorder: - TODO + def __init__(self): + TODO + + def save_current(self): + TODO + + def get_last_snapshot(self): + return TODO From 1b0070041e8ac60342e5b2c91b17e02cf77027e8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:27:04 +0800 Subject: [PATCH 0351/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 267d1142719..e286ecba4f2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,3 +1,5 @@ class EPLBManager: + def __init__(self): + self._expert_distribution_recorder= ExpertDistributionRecorder() async def rebalance_experts(self): TODO From 82d0d38f0090b6a9293bd4fbf321b7092a74736c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:27:26 +0800 Subject: [PATCH 0352/1089] more --- python/sglang/srt/managers/eplb_manager.py | 6 +++++- ...tribution_recorder.py => expert_distribution_storage.py} | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) rename python/sglang/srt/managers/{expert_distribution_recorder.py => expert_distribution_storage.py} (79%) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index e286ecba4f2..4ba16e7c4ba 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,5 +1,9 @@ +from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage + + class EPLBManager: def __init__(self): - self._expert_distribution_recorder= ExpertDistributionRecorder() + self._expert_distribution_storage = ExpertDistributionStorage() + async def rebalance_experts(self): TODO diff --git a/python/sglang/srt/managers/expert_distribution_recorder.py b/python/sglang/srt/managers/expert_distribution_storage.py similarity index 79% rename from python/sglang/srt/managers/expert_distribution_recorder.py rename to python/sglang/srt/managers/expert_distribution_storage.py index 6cb39dadee6..bc857888398 100644 --- a/python/sglang/srt/managers/expert_distribution_recorder.py +++ 
b/python/sglang/srt/managers/expert_distribution_storage.py
@@ -1,4 +1,4 @@
-class ExpertDistributionRecorder:
+class ExpertDistributionStorage:
     def __init__(self):
         TODO
 

From c3479c01053dfc1048e14c29f898bd7e1b5bacf9 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 17:28:15 +0800
Subject: [PATCH 0353/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index 4ba16e7c4ba..ae4d1878522 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -6,4 +6,6 @@ def __init__(self):
         self._expert_distribution_storage = ExpertDistributionStorage()
 
     async def rebalance_experts(self):
+        TODO_may_or_may_not_save_current
+        self._expert_distribution_storage.get_last_snapshot()
         TODO

From 42116e10607fca4b85d897f950b1c6e3225af147 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 17:29:19 +0800
Subject: [PATCH 0354/1089] more

---
 python/sglang/srt/managers/eplb_manager.py      |  9 ++++++++-
 python/sglang/srt/managers/tokenizer_manager.py | 11 +++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index ae4d1878522..a387dcaa50a 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -1,11 +1,18 @@
+from typing import TYPE_CHECKING
+
 from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
 
 class EPLBManager:
-    def __init__(self):
+    def __init__(self, tokenizer_manager: TokenizerManager):
+        self._tokenizer_manager = tokenizer_manager
         self._expert_distribution_storage = ExpertDistributionStorage()
 
     async def rebalance_experts(self):
         TODO_may_or_may_not_save_current
         self._expert_distribution_storage.get_last_snapshot()
         TODO
+        await self._tokenizer_manager.update_expert_location_metadata(TODO)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index c669aef2daf..65d0a315dc0 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -45,7 +45,6 @@
 import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
-
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.conn import KVBootstrapServer
@@ -202,7 +201,7 @@ def __init__(
             revision=server_args.revision,
         )
 
-        self.eplb_manager = EPLBManager()
+        self.eplb_manager = EPLBManager(self)
 
         # Store states
         self.no_create_loop = False
@@ -712,6 +711,10 @@ async def rebalance_experts(self):
         self.auto_create_handle_loop()
         await self.eplb_manager.rebalance_experts()
 
+    async def update_expert_location_metadata(self):
+        self.auto_create_handle_loop()
+        TODO
+
     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
@@ -969,8 +972,8 @@ def _handle_batch_output(
         elif isinstance(recv_obj, BatchTokenIDOut):
             if self.server_args.stream_output and state.obj.stream:
                 output_token_ids = recv_obj.output_ids[i][
-                    state.last_output_offset :
-                ]
+                    state.last_output_offset:
+                ]
                 state.last_output_offset = len(recv_obj.output_ids[i])
             else:
                 output_token_ids = recv_obj.output_ids[i]

From 4d772bd879cf5b687693430a144a5bc5dfd08513 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 17:29:45 +0800
Subject: [PATCH 0355/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index a387dcaa50a..dbf2da204bf 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -14,5 +14,5 @@ def __init__(self, tokenizer_manager: TokenizerManager):
     async def rebalance_experts(self):
         TODO_may_or_may_not_save_current
         self._expert_distribution_storage.get_last_snapshot()
-        TODO
-        await self._tokenizer_manager.update_expert_location_metadata(TODO)
+        expert_location_metadata = TODO
+        await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata)

From 11c08b03d94986323024eeebe6721e2b3b79f7a3 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 17:31:27 +0800
Subject: [PATCH 0356/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index dbf2da204bf..f0a3c819d43 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING
 
+from sglang.srt.managers import deepseek_eplb
 from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
 
 if TYPE_CHECKING:
@@ -16,3 +17,14 @@ async def rebalance_experts(self):
         self._expert_distribution_storage.get_last_snapshot()
         expert_location_metadata = TODO
         await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata)
+
+    def get_expert_location_metadata(self):
+        logical_count = self._expert_distribution_storage.get_last_snapshot()
+        physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts(
+            weight=logical_count,
+            num_replicas=TODO,
+            num_groups=TODO,
+            num_nodes=TODO,
+            num_gpus=TODO,
+        )
+        return TODO

From a997a46add6d65adecb5667dc4a24036ef5f3e42 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 17:31:48 +0800
Subject: [PATCH 0357/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 24 +++++++++++-----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index f0a3c819d43..ed703911683 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -14,17 +14,17 @@ def __init__(self, tokenizer_manager: TokenizerManager):
     async def rebalance_experts(self):
         TODO_may_or_may_not_save_current
-        self._expert_distribution_storage.get_last_snapshot()
-        expert_location_metadata = TODO
+        logical_count = self._expert_distribution_storage.get_last_snapshot()
+        expert_location_metadata = _compute_expert_location_metadata(logical_count)
         await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata)
+
+def _compute_expert_location_metadata(logical_count):
+    physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts(
+        weight=logical_count,
+        num_replicas=TODO,
+
num_groups=TODO, + num_nodes=TODO, + num_gpus=TODO, + ) + return TODO From a3783f09010034cdb117ce421bfe3e54eb579f32 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:32:00 +0800 Subject: [PATCH 0358/1089] more --- python/sglang/srt/managers/eplb_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index ed703911683..a164eaadcaf 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -2,6 +2,7 @@ from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage +from sglang.srt.managers.expert_location import ExpertLocationMetadata if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -27,4 +28,6 @@ def _compute_expert_location_metadata(logical_count): num_nodes=TODO, num_gpus=TODO, ) - return TODO + return ExpertLocationMetadata( + TODO=TODO, + ) From 7e78dfef61fe9e43d91815fe3b8b0a0beb622176 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:32:40 +0800 Subject: [PATCH 0359/1089] more --- python/sglang/srt/managers/eplb_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index a164eaadcaf..436135c3cc0 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,8 +1,10 @@ from typing import TYPE_CHECKING +import torch from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -20,12 +22,12 @@ async def rebalance_experts(self): await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata) -def _compute_expert_location_metadata(logical_count): +def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: torch.Tensor): physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, num_replicas=TODO, num_groups=TODO, - num_nodes=TODO, + num_nodes=server_args.nnodes, num_gpus=TODO, ) return ExpertLocationMetadata( From 78a2138f2b6d8fb8e8ba08458fc7648797698cda Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:33:03 +0800 Subject: [PATCH 0360/1089] more --- python/sglang/srt/managers/eplb_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 436135c3cc0..1c3044f42a0 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -28,7 +28,8 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to num_replicas=TODO, num_groups=TODO, num_nodes=server_args.nnodes, - num_gpus=TODO, + # TODO Consider scenario when disabling DP attn + DP size > 1 + num_gpus=server_args.tp_size, ) return ExpertLocationMetadata( TODO=TODO, From fb5e00f2578163e59cf7a59a6d5fb90debe0a982 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:33:35 +0800 Subject: [PATCH 0361/1089] more --- python/sglang/srt/managers/eplb_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 1c3044f42a0..899e13012cb 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -22,6 +22,7 @@ async def rebalance_experts(self): await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata) +# TODO maybe move to ExpertLocationMetadata static method? def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: torch.Tensor): physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, From 2c4c689f3dcb3d5633b997e28e2cfe83f1eae0f7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:34:05 +0800 Subject: [PATCH 0362/1089] more --- python/sglang/srt/managers/eplb_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 899e13012cb..2a49516d08b 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -33,5 +33,10 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to num_gpus=server_args.tp_size, ) return ExpertLocationMetadata( - TODO=TODO, + num_layers=TODO, + num_local_physical_experts=TODO, + num_logical_experts=TODO, + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_physical_map, + logical_to_rank_dispatch_physical_map=TODO, ) From 804d52bb91632ec77b20eb92d3bf09eb62fb0b7f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:36:24 +0800 Subject: [PATCH 0363/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 2a49516d08b..88143702b28 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -27,7 +27,7 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, num_replicas=TODO, - num_groups=TODO, + num_groups=config.n_group, num_nodes=server_args.nnodes, # TODO Consider scenario when disabling DP attn + DP size > 1 num_gpus=server_args.tp_size, From 8d958258c2004e5b10186f369ab2b869a1777b36 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:40:02 +0800 Subject: [PATCH 0364/1089] more --- python/sglang/srt/managers/expert_location.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 825e7314c72..151d2fa826b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,7 +2,6 @@ from typing import List import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -31,7 +30,7 @@ def from_model_config(model_config: ModelConfig): @staticmethod def init_new(num_layers: int, num_logical_experts: int): num_physical_experts = ( - num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) 
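[Note on the rebalance call being wired up in the last few commits: deepseek_eplb.rebalance_experts takes a per-layer load matrix over logical experts and returns a replicated physical layout. A toy invocation with made-up sizes; the keyword arguments mirror the diffs above, but the exact return shapes are an assumption based on DeepSeek's open-source EPLB:

import torch

from sglang.srt.managers import deepseek_eplb

# 2 MoE layers, 8 logical experts; loads are made-up token counts
logical_count = torch.randint(1, 100, (2, 8))

phy2log, log2phy, logcnt = deepseek_eplb.rebalance_experts(
    weight=logical_count,
    num_replicas=12,  # 8 logical experts + 4 redundant physical slots
    num_groups=4,     # expert-group count, e.g. config.n_group for DeepSeek
    num_nodes=1,
    num_gpus=4,
)
# phy2log: (2, 12) physical slot -> logical expert id
# log2phy: (2, 8, max_replicas) logical expert -> physical slots, -1 padded
# logcnt:  (2, 8)  replica count per logical expert

]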
world_size = get_tensor_model_parallel_world_size() assert num_physical_experts % world_size == 0 @@ -44,7 +43,7 @@ def init_new(num_layers: int, num_logical_experts: int): physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 ) - % num_logical_experts, + % num_logical_experts, # Throw away the redundant experts here - highly inefficient, but we do not care since we will # use EPLB distribution logic logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( @@ -68,7 +67,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -77,3 +76,10 @@ def logical_to_all_physical( ].tolist() if physical_expert_id != -1 ] + + +@dataclass +class ModelConfigForExpertLocation: + num_layers: int + num_logical_experts: int + num_groups: int From 84dd409f00fc455bea70e8c3567a6e506e498ca1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:40:35 +0800 Subject: [PATCH 0365/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- python/sglang/srt/models/deepseek_v2.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 151d2fa826b..de38a5b771c 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional import torch from sglang.srt.configs.model_config import ModelConfig @@ -82,4 +82,4 @@ def logical_to_all_physical( class ModelConfigForExpertLocation: num_layers: int num_logical_experts: int - num_groups: int + num_groups: Optional[int] = None diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d0bb51f05ff..e9ff0548e48 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -69,7 +69,7 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -1614,10 +1614,11 @@ def set_embed_and_head(self, embed, head): torch.cuda.synchronize() @classmethod - def get_expert_location_metadata(cls, config): - return ExpertLocationMetadata.init_new( + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( num_layers=config.num_hidden_layers, num_logical_experts=config.n_routed_experts, + num_groups=config.n_groups, ) From 5f0e0bc1d9f01553cf86982729d0d82396b06d40 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:40:58 +0800 Subject: [PATCH 0366/1089] more --- python/sglang/srt/models/qwen2_moe.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index d724d1afa65..47b9850add6 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ 
b/python/sglang/srt/models/qwen2_moe.py @@ -20,9 +20,6 @@ import torch import torch.nn.functional as F -from torch import nn -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -45,10 +42,12 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix +from torch import nn +from transformers import PretrainedConfig class Qwen2MoeMLP(nn.Module): @@ -195,7 +194,7 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -488,10 +487,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) @classmethod - def get_expert_location_metadata(cls, config): - return ExpertLocationMetadata.init_new( + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( num_layers=config.num_hidden_layers, num_logical_experts=config.num_experts, + num_groups=None, ) From db52e916984a921d0054f115fe53db7f8b0969ad Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:41:40 +0800 Subject: [PATCH 0367/1089] more --- python/sglang/srt/managers/expert_location.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index de38a5b771c..1ae6263196a 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -23,8 +23,9 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) - if hasattr(model_class, "get_expert_location_metadata"): - return model_class.get_expert_location_metadata(model_config.hf_config) + if hasattr(model_class, "get_model_config_for_expert_location"): + model_config_for_expert_location = model_class.get_model_config_for_expert_location(model_config.hf_config) + return TODO return ExpertLocationMetadata._init_dummy() @staticmethod From f5ee24d43e358fdeda890c9c73ff4db8e9a3cb7d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:42:07 +0800 Subject: [PATCH 0368/1089] more --- python/sglang/srt/managers/expert_location.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 1ae6263196a..b73e7a71fa0 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -22,6 +22,7 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): + TODO_this_function model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_model_config_for_expert_location"): model_config_for_expert_location = model_class.get_model_config_for_expert_location(model_config.hf_config) From 
af4e8f3e0f585df74d0193453a6f2ee44bb496fd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:50:18 +0800 Subject: [PATCH 0369/1089] more --- python/sglang/srt/managers/expert_location.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index b73e7a71fa0..60b64926fa5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -26,8 +26,9 @@ def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_model_config_for_expert_location"): model_config_for_expert_location = model_class.get_model_config_for_expert_location(model_config.hf_config) - return TODO - return ExpertLocationMetadata._init_dummy() + else: + model_config_for_expert_location = ModelConfigForExpertLocation.init_dummy() + return TODO @staticmethod def init_new(num_layers: int, num_logical_experts: int): @@ -56,10 +57,6 @@ def init_new(num_layers: int, num_logical_experts: int): ).repeat(num_layers, 1)[..., None], ) - @staticmethod - def _init_dummy(): - return ExpertLocationMetadata.init_new(num_layers=1, num_logical_experts=1) - # -------------------------------- usage ------------------------------------ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): @@ -85,3 +82,7 @@ class ModelConfigForExpertLocation: num_layers: int num_logical_experts: int num_groups: Optional[int] = None + + @staticmethod + def init_dummy(): + return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1) From 022bbfee0bf060132c06f25abf226a399c3a7800 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:51:26 +0800 Subject: [PATCH 0370/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 60b64926fa5..a0bdfade517 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -31,7 +31,8 @@ def from_model_config(model_config: ModelConfig): return TODO @staticmethod - def init_new(num_layers: int, num_logical_experts: int): + def init_trivial(num_layers: int, num_logical_experts: int): + """Trivial location - logical expert i corresponds to physical expert i""" num_physical_experts = ( num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) From 62765988632276f39d3fae9ad3751584babf0758 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:52:06 +0800 Subject: [PATCH 0371/1089] more --- python/sglang/srt/managers/expert_location.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index a0bdfade517..dfc317dbc08 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -22,12 +22,7 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): - TODO_this_function - model_class, _ = get_model_architecture(model_config) - if hasattr(model_class, "get_model_config_for_expert_location"): - model_config_for_expert_location = model_class.get_model_config_for_expert_location(model_config.hf_config) - else: - model_config_for_expert_location = 
ModelConfigForExpertLocation.init_dummy() + model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) return TODO @staticmethod @@ -87,3 +82,11 @@ class ModelConfigForExpertLocation: @staticmethod def init_dummy(): return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1) + + @staticmethod + def from_model_config(model_config: ModelConfig): + model_class, _ = get_model_architecture(model_config) + if hasattr(model_class, "get_model_config_for_expert_location"): + return model_class.get_model_config_for_expert_location(model_config.hf_config) + else: + return ModelConfigForExpertLocation.init_dummy() From 9b1a6f2cc9245e8169d1c855e0300efdc3ed31eb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:53:30 +0800 Subject: [PATCH 0372/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 ++++++-- python/sglang/srt/managers/expert_location.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 88143702b28..3814c2fe13c 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING import torch + +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage -from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -24,6 +26,8 @@ async def rebalance_experts(self): # TODO maybe move to ExpertLocationMetadata static method? 
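[Note on the trivial layout: the init_trivial construction from patch 0370 is easy to state concretely — physical slot i (redundant slots included) serves logical expert i % num_logical_experts, and dispatch initially ignores the redundant copies, exactly as the in-diff comment says. A self-contained rendering with small made-up sizes:

import torch

num_layers, num_logical, num_redundant = 2, 4, 2
num_physical = num_logical + num_redundant

# physical slot -> logical expert, round-robin over logical experts
physical_to_logical = torch.arange(num_physical).repeat(num_layers, 1) % num_logical

# logical expert -> the single identity physical slot; redundant slots are
# thrown away here, matching the comment in the diff above
logical_to_all_physical = torch.arange(num_logical).repeat(num_layers, 1)[..., None]

print(physical_to_logical[0])                  # tensor([0, 1, 2, 3, 0, 1])
print(logical_to_all_physical[0].squeeze(-1))  # tensor([0, 1, 2, 3])

]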
def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: torch.Tensor): + model_config = ModelConfig.from_server_args(server_args) + model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, num_replicas=TODO, @@ -38,5 +42,5 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to num_logical_experts=TODO, physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_physical_map, - logical_to_rank_dispatch_physical_map=TODO, + logical_to_rank_dispatch_physical_map=TODO_compute, ) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index dfc317dbc08..351e5a8fcb4 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -22,7 +22,6 @@ class ExpertLocationMetadata: @staticmethod def from_model_config(model_config: ModelConfig): - model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) return TODO @staticmethod From a80af1bdbc38305ae538f173bf182a25c6abe137 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:53:50 +0800 Subject: [PATCH 0373/1089] more --- python/sglang/srt/managers/eplb_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 3814c2fe13c..1650ab8979e 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -31,15 +31,15 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, num_replicas=TODO, - num_groups=config.n_group, + num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, # TODO Consider scenario when disabling DP attn + DP size > 1 num_gpus=server_args.tp_size, ) return ExpertLocationMetadata( - num_layers=TODO, + num_layers=model_config_for_expert_location.num_layers, num_local_physical_experts=TODO, - num_logical_experts=TODO, + num_logical_experts=model_config_for_expert_location.num_logical_experts, physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_physical_map, logical_to_rank_dispatch_physical_map=TODO_compute, From 9f0ba54050bf024fc6656e64405c265176ee7b69 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:54:45 +0800 Subject: [PATCH 0374/1089] more --- python/sglang/srt/managers/eplb_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 1650ab8979e..5031d23b9c9 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -42,5 +41,9 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to num_logical_experts=model_config_for_expert_location.num_logical_experts, physical_to_logical_map=physical_to_logical_map, 
logical_to_all_physical_map=logical_to_physical_map, - logical_to_rank_dispatch_physical_map=TODO_compute, + logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(logical_to_physical_map), ) + + +def _compute_logical_to_rank_dispatch_physical_map(logical_to_physical_map: torch.Tensor): + return TODO From f7ae980e40040f1d6d91a5cab7bc421d2108228b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:55:28 +0800 Subject: [PATCH 0375/1089] more --- python/sglang/srt/managers/eplb_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 5031d23b9c9..1ea81aa26f6 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -32,8 +32,7 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to num_replicas=TODO, num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, - # TODO Consider scenario when disabling DP attn + DP size > 1 - num_gpus=server_args.tp_size, + num_gpus=world_size, ) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, From 211fc81854c03968d508f322d4d861eabf9b1ab2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:55:51 +0800 Subject: [PATCH 0376/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 1ea81aa26f6..6e8f3de2657 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -29,7 +29,7 @@ def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: to model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( weight=logical_count, - num_replicas=TODO, + num_replicas=num_physical_experts, num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, num_gpus=world_size, From 31692d776efa01f529de63480d1a123afcbb2a68 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:57:33 +0800 Subject: [PATCH 0377/1089] more --- python/sglang/srt/managers/eplb_manager.py | 7 +++++-- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 6e8f3de2657..5b26d15ee4f 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -18,10 +18,13 @@ def __init__(self, tokenizer_manager: TokenizerManager): async def rebalance_experts(self): TODO_may_or_may_not_save_current - logical_count = self._expert_distribution_storage.get_last_snapshot() - expert_location_metadata = _compute_expert_location_metadata(logical_count) + expert_location_metadata = self.get_expert_location_metadata() await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata) + def get_expert_location_metadata(self): + logical_count = self._expert_distribution_storage.get_last_snapshot() + return _compute_expert_location_metadata(self._tokenizer_manager.server_args, logical_count) + # TODO maybe move to ExpertLocationMetadata static method? 
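[Note on the remaining TODO: logical_to_rank_dispatch_physical_map has to give every GPU rank one concrete physical replica per logical expert. One plausible policy — a sketch only, not necessarily what this series ultimately lands — spreads the available replicas across ranks round-robin:

import torch


def compute_dispatch_map_sketch(
    logical_to_all_physical: torch.Tensor, num_gpus: int
) -> torch.Tensor:
    """logical_to_all_physical: (layers, logical, max_replicas), -1 padded.

    Returns (num_gpus, layers, logical): one physical slot per rank."""
    layers, logical, _ = logical_to_all_physical.shape
    out = torch.empty((num_gpus, layers, logical), dtype=torch.long)
    for layer in range(layers):
        for log in range(logical):
            replicas = [
                p for p in logical_to_all_physical[layer, log].tolist() if p != -1
            ]
            for rank in range(num_gpus):
                # Round-robin so replicas of a hot expert share the traffic.
                out[rank, layer, log] = replicas[rank % len(replicas)]
    return out

]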
def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: torch.Tensor): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 65d0a315dc0..a5dd70f963a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -711,7 +711,7 @@ async def rebalance_experts(self): self.auto_create_handle_loop() await self.eplb_manager.rebalance_experts() - async def update_expert_location_metadata(self): + async def update_expert_location_metadata(self, expert_location_metadata: ExpertLocationMetadata): self.auto_create_handle_loop() TODO From 69c57776e61e0fa521211f0274ec0c40b847fc93 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:58:03 +0800 Subject: [PATCH 0378/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 5b26d15ee4f..c479468f186 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -12,18 +12,18 @@ class EPLBManager: - def __init__(self, tokenizer_manager: TokenizerManager): - self._tokenizer_manager = tokenizer_manager + def __init__(self, server_args: ServerArgs): + self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage() async def rebalance_experts(self): TODO_may_or_may_not_save_current expert_location_metadata = self.get_expert_location_metadata() - await self._tokenizer_manager.update_expert_location_metadata(expert_location_metadata) + await self.tokenizer_manager.update_expert_location_metadata(expert_location_metadata) def get_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() - return _compute_expert_location_metadata(self._tokenizer_manager.server_args, logical_count) + return _compute_expert_location_metadata(self._server_args, logical_count) # TODO maybe move to ExpertLocationMetadata static method? 
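[Note on the wiring that follows: the next three patches resolve a chicken-and-egg problem — the engine needs EPLBManager before TokenizerManager exists, because expert locations must be computed before the scheduler processes launch, yet rebalancing later calls back into TokenizerManager. The result is two-phase initialization; schematically (argument lists abbreviated, and the method is renamed compute_expert_location_metadata a few commits later):

# Phase 1: EPLBManager is constructed from server_args alone.
eplb_manager = EPLBManager(server_args)
expert_location_metadata = eplb_manager.get_expert_location_metadata()

# ... scheduler subprocesses are spawned with expert_location_metadata ...

# Phase 2: TokenizerManager receives the manager and closes the cycle by
# assigning eplb_manager.tokenizer_manager = self inside its __init__.
tokenizer_manager = TokenizerManager(
    server_args, port_args, expert_location_metadata, eplb_manager
)

]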
From 3af3fdcac8c13285f702e470c7dc3ba4be78a29e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:58:57 +0800 Subject: [PATCH 0379/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a5dd70f963a..41386f55be1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -140,6 +140,7 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, expert_location_metadata: ExpertLocationMetadata, + eplb_manager: EPLBManager, ): # Parse args self.server_args = server_args @@ -201,7 +202,8 @@ def __init__( revision=server_args.revision, ) - self.eplb_manager = EPLBManager(self) + eplb_manager.tokenizer_manager = self + self.eplb_manager = eplb_manager # Store states self.no_create_loop = False From a46be2cbcff0ca3be8b1ccde79228d6f90de0e45 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:59:14 +0800 Subject: [PATCH 0380/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 +++- python/sglang/srt/managers/eplb_manager.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 322c1280525..f5473a8cc99 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -32,6 +32,7 @@ from PIL.Image import Image from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata # Fix a bug of Python threading @@ -501,6 +502,7 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) + eplb_manager = EPLBManager(server_args) model_config = ModelConfig.from_server_args(server_args) expert_location_metadata = ExpertLocationMetadata.from_model_config(model_config) @@ -583,7 +585,7 @@ def _launch_subprocesses( # Launch tokenizer process tokenizer_manager = TokenizerManager( - server_args, port_args, expert_location_metadata + server_args, port_args, expert_location_metadata, eplb_manager ) if server_args.chat_template: load_chat_template_for_openai_api( diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index c479468f186..258674187dc 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -15,6 +15,7 @@ class EPLBManager: def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage() + self.tokenizer_manager: Optional[TokenizerManager] = None async def rebalance_experts(self): TODO_may_or_may_not_save_current From 0cbd4b47f4f240a2254c6591d9731bde25bfa826 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:59:41 +0800 Subject: [PATCH 0381/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/eplb_manager.py | 2 +- python/sglang/srt/managers/expert_location.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f5473a8cc99..c9c60241b67 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -504,7 +504,7 @@ def _launch_subprocesses( eplb_manager = EPLBManager(server_args) model_config = 
ModelConfig.from_server_args(server_args) - expert_location_metadata = ExpertLocationMetadata.from_model_config(model_config) + expert_location_metadata = eplb_manager.get_expert_location_metadata() scheduler_procs = [] if server_args.dp_size == 1: diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 258674187dc..7da0c1f0ca7 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch from sglang.srt.configs.model_config import ModelConfig diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 351e5a8fcb4..5db25a17ba2 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,10 +20,6 @@ class ExpertLocationMetadata: # -------------------------------- construction and mutation ------------------------------------ - @staticmethod - def from_model_config(model_config: ModelConfig): - return TODO - @staticmethod def init_trivial(num_layers: int, num_logical_experts: int): """Trivial location - logical expert i corresponds to physical expert i""" From 1ba8466a426010206dee49966e3762dfbb8feb73 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 17:59:52 +0800 Subject: [PATCH 0382/1089] more --- python/sglang/srt/entrypoints/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index c9c60241b67..69301d86012 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -503,7 +503,6 @@ def _launch_subprocesses( ) eplb_manager = EPLBManager(server_args) - model_config = ModelConfig.from_server_args(server_args) expert_location_metadata = eplb_manager.get_expert_location_metadata() scheduler_procs = [] From e3186cf1e98450c2b947baf9f59d40bd515d1b17 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:00:07 +0800 Subject: [PATCH 0383/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/eplb_manager.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 69301d86012..b518e778efd 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -503,7 +503,7 @@ def _launch_subprocesses( ) eplb_manager = EPLBManager(server_args) - expert_location_metadata = eplb_manager.get_expert_location_metadata() + expert_location_metadata = eplb_manager.compute_expert_location_metadata() scheduler_procs = [] if server_args.dp_size == 1: diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 7da0c1f0ca7..a66e1a36967 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -19,16 +19,16 @@ def __init__(self, server_args: ServerArgs): async def rebalance_experts(self): TODO_may_or_may_not_save_current - expert_location_metadata = self.get_expert_location_metadata() + expert_location_metadata = self.compute_expert_location_metadata() await self.tokenizer_manager.update_expert_location_metadata(expert_location_metadata) - def get_expert_location_metadata(self): + def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() - 
return _compute_expert_location_metadata(self._server_args, logical_count) + return _compute_expert_location_metadata_raw(self._server_args, logical_count) # TODO maybe move to ExpertLocationMetadata static method? -def _compute_expert_location_metadata(server_args: ServerArgs, logical_count: torch.Tensor): +def _compute_expert_location_metadata_raw(server_args: ServerArgs, logical_count: torch.Tensor): model_config = ModelConfig.from_server_args(server_args) model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( From 1582464b1c95f2f6a271829694830e0447e33aef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:02:19 +0800 Subject: [PATCH 0384/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index a66e1a36967..d29b21b7602 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -24,6 +24,8 @@ async def rebalance_experts(self): def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() + if logical_count is None: + return TODO_default return _compute_expert_location_metadata_raw(self._server_args, logical_count) From c43c110f7cecbfa2ef1e3b57a489d0dbf28d4065 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:03:11 +0800 Subject: [PATCH 0385/1089] more --- python/sglang/srt/server_args.py | 76 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b7de3cd026b..766eb3420e0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -163,6 +163,7 @@ class ServerArgs: enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 + enable_eplb: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -420,8 +421,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -445,21 +446,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." 
- '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -472,13 +473,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -513,9 +514,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -557,8 +558,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -578,7 +579,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. 
" - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1043,7 +1044,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1056,8 +1057,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1109,6 +1110,11 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.ep_num_redundant_experts, help="Allocate this number of redundant experts in expert parallel.", ) + parser.add_argument( + "--enable-eplb", + action="store_true", + help="Enable EPLB algorithm", + ) parser.add_argument( "--deepep-mode", type=str, @@ -1122,7 +1128,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1136,7 +1142,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 981a6c45fb25405f9364be105b849152718cd0d2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:04:26 +0800 Subject: [PATCH 0386/1089] more --- python/sglang/srt/managers/eplb_manager.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index d29b21b7602..4cf3e1acd5d 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -12,6 +12,17 @@ class EPLBManager: + @staticmethod + def init_new(server_args: ServerArgs): + if server_args.enable_eplb: + return _EPLBManagerReal(server_args) + else: + return _EPLBManagerNoop() + + def compute_expert_location_metadata(self) -> ExpertLocationMetadata: + return TODO + +class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage() @@ -25,10 +36,14 @@ async def rebalance_experts(self): def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() if logical_count is None: - return TODO_default + return super().compute_expert_location_metadata() return _compute_expert_location_metadata_raw(self._server_args, logical_count) +class _EPLBManagerNoop(EPLBManager): + pass + + # TODO maybe move to ExpertLocationMetadata static method? def _compute_expert_location_metadata_raw(server_args: ServerArgs, logical_count: torch.Tensor): model_config = ModelConfig.from_server_args(server_args) From 1baf771f71d4842a75148dfc970847647da27a7a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:04:50 +0800 Subject: [PATCH 0387/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/eplb_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b518e778efd..3c4159dd897 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -502,7 +502,7 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - eplb_manager = EPLBManager(server_args) + eplb_manager = EPLBManager.init_new(server_args) expert_location_metadata = eplb_manager.compute_expert_location_metadata() scheduler_procs = [] diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 4cf3e1acd5d..506d365c573 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -20,7 +20,7 @@ def init_new(server_args: ServerArgs): return _EPLBManagerNoop() def compute_expert_location_metadata(self) -> ExpertLocationMetadata: - return TODO + return TODO_trivial_output class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): From dc9dd3ebaa224f1fc9ec2cc4c1bda5d60c92ea2e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:05:04 +0800 Subject: [PATCH 0388/1089] more --- python/sglang/srt/managers/eplb_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 506d365c573..b2bf65c6f03 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -19,14 +19,17 @@ def init_new(server_args: ServerArgs): else: return _EPLBManagerNoop() + def __init__(self): + self.tokenizer_manager: Optional[TokenizerManager] = None + def compute_expert_location_metadata(self) -> ExpertLocationMetadata: return TODO_trivial_output class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): + super().__init__() self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage() - self.tokenizer_manager: Optional[TokenizerManager] = None async def rebalance_experts(self): TODO_may_or_may_not_save_current From 9682c56e927fdb493c047eab43cafa8f9fb933f6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:08:17 +0800 Subject: [PATCH 0389/1089] rm --- python/sglang/srt/entrypoints/engine.py | 3 --- python/sglang/srt/entrypoints/http_server.py | 16 +++++----------- python/sglang/srt/managers/eplb_manager.py | 5 ----- python/sglang/srt/managers/tokenizer_manager.py | 8 -------- 4 files changed, 5 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 3c4159dd897..78095584e71 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -302,9 +302,6 @@ def get_server_info(self): "version": __version__, } - def rebalance_experts(self): - self.tokenizer_manager.rebalance_experts() - def init_weights_update_group( self, master_address: str, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 71c30d38370..672b60fe63a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -370,12 +370,6 @@ async def dump_expert_distribution_record_async(): return ORJSONResponse(content, status_code=200) -@app.post("/rebalance_experts") -async def rebalance_experts(): - await _global_state.tokenizer_manager.rebalance_experts() - return ORJSONResponse({}, status_code=200) - - @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk inplace without re-launching the server.""" @@ -640,10 +634,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b2bf65c6f03..649d6fde5fe 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -31,11 +31,6 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage() - async def rebalance_experts(self): - TODO_may_or_may_not_save_current - expert_location_metadata = self.compute_expert_location_metadata() - await self.tokenizer_manager.update_expert_location_metadata(expert_location_metadata) - def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() if logical_count is None: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 41386f55be1..199ae66f0a0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -709,14 +709,6 @@ async def _wait_for_model_update_from_disk( all_paused_requests = [r.num_paused_requests for r in result] return all_success, all_message, all_paused_requests - async def rebalance_experts(self): - self.auto_create_handle_loop() - await self.eplb_manager.rebalance_experts() - - async def update_expert_location_metadata(self, expert_location_metadata: ExpertLocationMetadata): - self.auto_create_handle_loop() - TODO - async def init_weights_update_group( self, obj: InitWeightsUpdateGroupReqInput, From 8979514ebae3d459c6a4f473b7d16a8d159be51f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:08:44 +0800 Subject: [PATCH 0390/1089] fmt --- python/sglang/srt/entrypoints/http_server.py | 10 +-- python/sglang/srt/managers/eplb_manager.py | 37 +++++++--- python/sglang/srt/managers/expert_location.py | 11 +-- .../sglang/srt/managers/tokenizer_manager.py | 5 +- python/sglang/srt/models/deepseek_v2.py | 5 +- python/sglang/srt/models/qwen2_moe.py | 12 ++-- python/sglang/srt/server_args.py | 70 +++++++++---------- 7 files changed, 88 insertions(+), 62 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 672b60fe63a..f43afec5423 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): 
_global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -634,10 +634,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 649d6fde5fe..5ba2c9a7ad5 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,10 +1,14 @@ from typing import TYPE_CHECKING, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage -from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation +from sglang.srt.managers.expert_location import ( + ExpertLocationMetadata, + ModelConfigForExpertLocation, +) from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -25,6 +29,7 @@ def __init__(self): def compute_expert_location_metadata(self) -> ExpertLocationMetadata: return TODO_trivial_output + class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): super().__init__() @@ -43,15 +48,21 @@ class _EPLBManagerNoop(EPLBManager): # TODO maybe move to ExpertLocationMetadata static method? 
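[Note on the structure being reformatted here: the _EPLBManagerReal/_EPLBManagerNoop split from patch 0386 is a null-object factory — the --enable-eplb flag is consulted once in init_new, and call sites never branch. Distilled to its skeleton (sglang specifics stripped; the trivial fallback is the base-class method both subclasses share):

class Manager:
    @staticmethod
    def init_new(enabled: bool) -> "Manager":
        return _ManagerReal() if enabled else _ManagerNoop()

    def compute_layout(self) -> str:
        return "trivial layout"  # shared fallback, used when EPLB is off


class _ManagerReal(Manager):
    def compute_layout(self) -> str:
        snapshot = None  # stand-in for storage.get_last_snapshot()
        if snapshot is None:
            # No recorded distribution yet: fall back to the trivial layout.
            return super().compute_layout()
        return "rebalanced layout"


class _ManagerNoop(Manager):
    pass  # inherits the trivial behavior


assert Manager.init_new(enabled=False).compute_layout() == "trivial layout"

]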
-def _compute_expert_location_metadata_raw(server_args: ServerArgs, logical_count: torch.Tensor): +def _compute_expert_location_metadata_raw( + server_args: ServerArgs, logical_count: torch.Tensor +): model_config = ModelConfig.from_server_args(server_args) - model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) - physical_to_logical_map, logical_to_physical_map, expert_count = deepseek_eplb.rebalance_experts( - weight=logical_count, - num_replicas=num_physical_experts, - num_groups=model_config_for_expert_location.num_groups, - num_nodes=server_args.nnodes, - num_gpus=world_size, + model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config( + model_config + ) + physical_to_logical_map, logical_to_physical_map, expert_count = ( + deepseek_eplb.rebalance_experts( + weight=logical_count, + num_replicas=num_physical_experts, + num_groups=model_config_for_expert_location.num_groups, + num_nodes=server_args.nnodes, + num_gpus=world_size, + ) ) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, @@ -59,9 +70,13 @@ def _compute_expert_location_metadata_raw(server_args: ServerArgs, logical_count num_logical_experts=model_config_for_expert_location.num_logical_experts, physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_physical_map, - logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(logical_to_physical_map), + logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( + logical_to_physical_map + ), ) -def _compute_logical_to_rank_dispatch_physical_map(logical_to_physical_map: torch.Tensor): +def _compute_logical_to_rank_dispatch_physical_map( + logical_to_physical_map: torch.Tensor, +): return TODO diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 5db25a17ba2..3b944a1053a 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,6 +2,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -24,7 +25,7 @@ class ExpertLocationMetadata: def init_trivial(num_layers: int, num_logical_experts: int): """Trivial location - logical expert i corresponds to physical expert i""" num_physical_experts = ( - num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] + num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) world_size = get_tensor_model_parallel_world_size() assert num_physical_experts % world_size == 0 @@ -37,7 +38,7 @@ def init_trivial(num_layers: int, num_logical_experts: int): physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 ) - % num_logical_experts, + % num_logical_experts, # Throw away the redundant experts here - highly inefficient, but we do not care since we will # use EPLB distribution logic logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( @@ -57,7 +58,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -82,6 +83,8 @@ def 
init_dummy(): def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) if hasattr(model_class, "get_model_config_for_expert_location"): - return model_class.get_model_config_for_expert_location(model_config.hf_config) + return model_class.get_model_config_for_expert_location( + model_config.hf_config + ) else: return ModelConfigForExpertLocation.init_dummy() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 199ae66f0a0..95ee879bbf0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -966,8 +967,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e9ff0548e48..2f99e657b6d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -69,7 +69,10 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation +from sglang.srt.managers.expert_location import ( + ExpertLocationMetadata, + ModelConfigForExpertLocation, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 47b9850add6..4cce4f0aca3 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -20,6 +20,9 @@ import torch import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -42,12 +45,13 @@ VocabParallelEmbedding, ) from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation +from sglang.srt.managers.expert_location import ( + ExpertLocationMetadata, + ModelConfigForExpertLocation, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix -from torch import nn -from transformers import PretrainedConfig class Qwen2MoeMLP(nn.Module): @@ -194,7 +198,7 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py 
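
The stream-output hunk in tokenizer_manager.py above relies on a simple offset cursor: recv_obj.output_ids[i] arrives as the full cumulative token list, and only the unseen suffix past last_output_offset is emitted. A standalone illustration of that pattern (the data values are made up):

    # Cumulative token ids as they might arrive across three decode steps.
    output_ids_history = [[1, 2], [1, 2, 3, 4], [1, 2, 3, 4, 5]]

    last_output_offset = 0
    for output_ids in output_ids_history:
        new_tokens = output_ids[last_output_offset:]  # only the unseen suffix
        last_output_offset = len(output_ids)
        print(new_tokens)
    # -> [1, 2] then [3, 4] then [5]
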
index 766eb3420e0..3e6810367e6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -421,8 +421,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -446,21 +446,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -473,13 +473,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -514,9 +514,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. 
", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -558,8 +558,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -579,7 +579,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1044,7 +1044,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1057,8 +1057,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1128,7 +1128,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1142,7 +1142,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 8775c9ead766d84692e05f4ead0ed99b32c0184e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:09:47 +0800 Subject: [PATCH 0391/1089] more --- python/sglang/srt/managers/expert_location.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3b944a1053a..27f3f9ef7d1 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -4,7 +4,6 @@ import torch from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture @@ -27,7 +26,7 @@ def init_trivial(num_layers: int, num_logical_experts: int): num_physical_experts = ( num_logical_experts + global_server_args_dict["ep_num_redundant_experts"] ) - world_size = get_tensor_model_parallel_world_size() + world_size = server_args.tp_size assert num_physical_experts % world_size == 0 num_local_physical_experts = num_physical_experts // world_size From 5a5fdb8290f8d60c02d3670e9b06fd5d718fdfb3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:10:54 +0800 Subject: [PATCH 0392/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 +++++++- python/sglang/srt/managers/expert_location.py | 7 ------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 5ba2c9a7ad5..1cfc823a3b6 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -55,6 +55,12 @@ def _compute_expert_location_metadata_raw( model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config( model_config ) + + num_physical_experts = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + world_size = server_args.tp_size + assert num_physical_experts % world_size == 0 + num_local_physical_experts = num_physical_experts // world_size + physical_to_logical_map, logical_to_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( weight=logical_count, @@ -66,7 +72,7 @@ def _compute_expert_location_metadata_raw( ) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, - num_local_physical_experts=TODO, + num_local_physical_experts=num_local_physical_experts, num_logical_experts=model_config_for_expert_location.num_logical_experts, physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_physical_map, diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 27f3f9ef7d1..d9156b3d923 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -23,13 +23,6 @@ class ExpertLocationMetadata: @staticmethod def init_trivial(num_layers: int, num_logical_experts: int): """Trivial location - logical expert i corresponds to physical expert i""" - num_physical_experts = ( - num_logical_experts + 
global_server_args_dict["ep_num_redundant_experts"] - ) - world_size = server_args.tp_size - assert num_physical_experts % world_size == 0 - num_local_physical_experts = num_physical_experts // world_size - return ExpertLocationMetadata( num_layers=num_layers, num_logical_experts=num_logical_experts, From 0597c10a724a4db2d39001c74537103aec9b9a2d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:11:47 +0800 Subject: [PATCH 0393/1089] more --- python/sglang/srt/managers/eplb_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 1cfc823a3b6..b919b2443cd 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -49,7 +49,7 @@ class _EPLBManagerNoop(EPLBManager): # TODO maybe move to ExpertLocationMetadata static method? def _compute_expert_location_metadata_raw( - server_args: ServerArgs, logical_count: torch.Tensor + server_args: ServerArgs, logical_count: torch.Tensor ): model_config = ModelConfig.from_server_args(server_args) model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config( @@ -57,6 +57,7 @@ def _compute_expert_location_metadata_raw( ) num_physical_experts = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size assert num_physical_experts % world_size == 0 num_local_physical_experts = num_physical_experts // world_size @@ -83,6 +84,6 @@ def _compute_expert_location_metadata_raw( def _compute_logical_to_rank_dispatch_physical_map( - logical_to_physical_map: torch.Tensor, + logical_to_physical_map: torch.Tensor, ): return TODO From 7350e8662a5d14dc61a7fafb086a6e88d33f2e56 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:11:53 +0800 Subject: [PATCH 0394/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b919b2443cd..ea6eb2369a2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -47,7 +46,6 @@ class _EPLBManagerNoop(EPLBManager): pass -# TODO maybe move to ExpertLocationMetadata static method? 
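
The expert-count bookkeeping repeated in these hunks is simple integer arithmetic: physical experts = logical experts + redundant replicas, split evenly across the tensor-parallel world. A quick self-contained check (256, 32, and 16 are illustrative numbers, not values taken from this series):

    num_logical_experts = 256        # illustrative MoE layer width
    ep_num_redundant_experts = 32    # illustrative --ep-num-redundant-experts
    world_size = 16                  # tp_size (see the DP > 1 TODO above)

    num_physical_experts = num_logical_experts + ep_num_redundant_experts
    assert num_physical_experts % world_size == 0  # must split evenly
    num_local_physical_experts = num_physical_experts // world_size
    print(num_physical_experts, num_local_physical_experts)  # 288, 18
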
def _compute_expert_location_metadata_raw( server_args: ServerArgs, logical_count: torch.Tensor ): From bde8ed10f503fcc9605fc9dc0f02cdbf1d1c4818 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:13:42 +0800 Subject: [PATCH 0395/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index ea6eb2369a2..3d424be81aa 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -26,7 +26,7 @@ def __init__(self): self.tokenizer_manager: Optional[TokenizerManager] = None def compute_expert_location_metadata(self) -> ExpertLocationMetadata: - return TODO_trivial_output + return ExpertLocationMetadata.init_trivial(TODO) class _EPLBManagerReal(EPLBManager): From bfbb6abdd1dd2c186d349a29d56a6d9409cf7a1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:21:35 +0800 Subject: [PATCH 0396/1089] more --- python/sglang/srt/model_loader/loader.py | 25 +++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 4e42ee897c0..cf269be78e5 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,20 +374,23 @@ def load_model( self.load_config, ) - model.load_weights(self._get_all_weights(model_config, model)) + self.load_weights_and_postprocess(model_config, model, target_device) - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) return model.eval() + def load_weights_and_postprocess(self, model_config, model, target_device): + model.load_weights(self._get_all_weights(model_config, model)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
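
The comment block above pins down the contract that the refactored load_weights_and_postprocess helper preserves: parameters are moved onto the target device only while a quant method postprocesses them, then moved back so CPU offloading still holds. Behaviorally, device_loading_context can be thought of as a context manager along these lines (a hedged sketch, not the actual implementation in loader.py):

    import contextlib

    import torch

    @contextlib.contextmanager
    def device_loading_context_sketch(module: torch.nn.Module, target_device):
        # Remember where the parameters live now (assumes a uniform device).
        params = list(module.parameters())
        original_device = params[0].device if params else target_device
        module.to(target_device)  # on-device while weights are repacked/quantized
        try:
            yield module
        finally:
            module.to(original_device)  # restore, e.g. back to CPU when offloading
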
+ with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) class LayeredModelLoader(DefaultModelLoader): """Model loader that loads weights layer by layer so that one can quantize a From dc7c425365ee7554d4995c8514d2273c38beb2fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:22:30 +0800 Subject: [PATCH 0397/1089] more --- python/sglang/srt/model_loader/loader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index cf269be78e5..1f245c7e081 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,12 +374,13 @@ def load_model( self.load_config, ) - self.load_weights_and_postprocess(model_config, model, target_device) + self.load_weights_and_postprocess(model, self._get_all_weights(model_config, model), target_device) return model.eval() - def load_weights_and_postprocess(self, model_config, model, target_device): - model.load_weights(self._get_all_weights(model_config, model)) + @staticmethod + def load_weights_and_postprocess(model, weights, target_device): + model.load_weights(weights) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) From 661de4a1693cb1ba81bfe605b30c2247f72c67a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:23:17 +0800 Subject: [PATCH 0398/1089] more --- python/sglang/srt/model_executor/model_runner.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7c9ec61da99..5663cafe1ab 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -489,12 +489,7 @@ def get_weight_iter(config): return iter def model_load_weights(model, iter): - model.load_weights(iter) - for _, module in self.model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) + DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) return model with set_default_torch_dtype(self.model_config.dtype): From 621f10239217f28dbc7fe36bf3d4597a12cf9e1b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:23:41 +0800 Subject: [PATCH 0399/1089] fmt --- python/sglang/srt/model_loader/loader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 1f245c7e081..af0105c7ecf 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,7 +374,9 @@ def load_model( self.load_config, ) - self.load_weights_and_postprocess(model, self._get_all_weights(model_config, model), target_device) + self.load_weights_and_postprocess( + model, self._get_all_weights(model_config, model), target_device + ) return model.eval() @@ -393,6 +395,7 @@ def load_weights_and_postprocess(model, weights, target_device): with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) + class LayeredModelLoader(DefaultModelLoader): """Model loader that loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller.""" From 
72fc8be6752faca39ccb543ef3f09293d0186a47 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:27:17 +0800 Subject: [PATCH 0400/1089] more --- python/sglang/srt/managers/eplb_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 3d424be81aa..f55df78e5a4 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -84,4 +84,5 @@ def _compute_expert_location_metadata_raw( def _compute_logical_to_rank_dispatch_physical_map( logical_to_physical_map: torch.Tensor, ): + # TODO maybe improve this algorithm (e.g. ensure it is really balanced) return TODO From c1f1516e6281acf1852c41562ba399bec4b97252 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:27:36 +0800 Subject: [PATCH 0401/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index f55df78e5a4..eef748518e3 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -60,7 +60,7 @@ def _compute_expert_location_metadata_raw( assert num_physical_experts % world_size == 0 num_local_physical_experts = num_physical_experts // world_size - physical_to_logical_map, logical_to_physical_map, expert_count = ( + physical_to_logical_map, logical_to_all_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( weight=logical_count, num_replicas=num_physical_experts, @@ -74,15 +74,15 @@ def _compute_expert_location_metadata_raw( num_local_physical_experts=num_local_physical_experts, num_logical_experts=model_config_for_expert_location.num_logical_experts, physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_physical_map, + logical_to_all_physical_map=logical_to_all_physical_map, logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_physical_map + logical_to_all_physical_map ), ) def _compute_logical_to_rank_dispatch_physical_map( - logical_to_physical_map: torch.Tensor, + logical_to_all_physical_map: torch.Tensor, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) return TODO From 24ce30e82e7ab97462d65f55c87bf79e65c1a81d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:30:32 +0800 Subject: [PATCH 0402/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index eef748518e3..05c02a325cb 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -85,4 +85,8 @@ def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) + + num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) + return TODO From 68b8c2a7af7571ca660218b8906170595f0768cd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:30:50 +0800 Subject: [PATCH 0403/1089] more --- python/sglang/srt/managers/eplb_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 05c02a325cb..af3c585709e 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -85,6 +85,7 @@ def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) + # This is rarely called, so we use for loops for maximum clarity num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) From 8a66fe6c0dbf83e19975a909fb248f569260586e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:31:30 +0800 Subject: [PATCH 0404/1089] more --- python/sglang/srt/managers/eplb_manager.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index af3c585709e..5c0961808d2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -76,13 +76,14 @@ def _compute_expert_location_metadata_raw( physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map + logical_to_all_physical_map, num_gpus=world_size, ), ) def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity @@ -90,4 +91,9 @@ def _compute_logical_to_rank_dispatch_physical_map( num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) - return TODO + for layer_id in range(num_layers): + for logical_expert_id in range(num_logical_experts): + for gpu_id in range(num_gpus): + TODO + + return logical_to_rank_dispatch_physical_map From 676195b6cd725a636976a675fe2003608aae7a8e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:32:56 +0800 Subject: [PATCH 0405/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 +++- python/sglang/srt/managers/expert_location.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 5c0961808d2..26ada992edc 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -94,6 +94,8 @@ def _compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): for gpu_id in range(num_gpus): - TODO + candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw(logical_to_all_physical_map, + layer_id, logical_expert_id) + logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = TODO return logical_to_rank_dispatch_physical_map diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d9156b3d923..bfdbcc257bc 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -2,7 +2,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture @@ -30,7 +29,7 @@ def init_trivial(num_layers: int, num_logical_experts: int): physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( num_layers, 1 ) - % num_logical_experts, + % num_logical_experts, # Throw away the redundant experts here - highly inefficient, but we do not care since we will # use EPLB distribution logic logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( @@ -50,11 +49,17 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int + ) -> List[int]: + return self.logical_to_all_physical_raw(self.logical_to_all_physical_map, layer_id, logical_expert_id) + + @staticmethod + def logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id - for physical_expert_id in self.logical_to_all_physical_map[ + for physical_expert_id in logical_to_all_physical_map[ layer_id, logical_expert_id ].tolist() if physical_expert_id != -1 From 58cddd62aaafefbcf9e77bf8bc421301282de789 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:33:23 +0800 Subject: [PATCH 0406/1089] more --- python/sglang/srt/managers/eplb_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py 
b/python/sglang/srt/managers/eplb_manager.py index 26ada992edc..6716bfae399 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,3 +1,4 @@ +import random from typing import TYPE_CHECKING, Optional import torch @@ -88,14 +89,16 @@ def _compute_logical_to_rank_dispatch_physical_map( # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity + r = random.Random() + num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): for gpu_id in range(num_gpus): - candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw(logical_to_all_physical_map, - layer_id, logical_expert_id) - logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = TODO + candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id) + logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = r.choice(candidate_values) return logical_to_rank_dispatch_physical_map From ca826487f8d04f4097953afb5a2fb02a29e66549 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 18:34:28 +0800 Subject: [PATCH 0407/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index bc857888398..a1f80cdd454 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -1,5 +1,8 @@ +from pathlib import Path + + class ExpertDistributionStorage: - def __init__(self): + def __init__(self, dir_data: Path): TODO def save_current(self): From 57d512c8e774309d830210ebba51a285780e066f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:49:41 +0800 Subject: [PATCH 0408/1089] more --- python/sglang/srt/entrypoints/engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 78095584e71..1d2e1d6a728 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -30,7 +30,6 @@ import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -500,7 +499,7 @@ def _launch_subprocesses( ) eplb_manager = EPLBManager.init_new(server_args) - expert_location_metadata = eplb_manager.compute_expert_location_metadata() + expert_location_metadata = _compute_initial_expert_location_metadata(server_args, eplb_manager) scheduler_procs = [] if server_args.dp_size == 1: @@ -614,3 +613,7 @@ def _launch_subprocesses( scheduler_info = scheduler_infos[0] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] return tokenizer_manager, scheduler_info + + +def _compute_initial_expert_location_metadata(server_args: ServerArgs, eplb_manager: EPLBManager): + return TODO From 4ae108c81fe4c2c28e890f66b2339cc763fb55da Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:50:50 +0800 Subject: [PATCH 
0409/1089] more --- python/sglang/srt/server_args.py | 77 +++++++++++++++++--------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3e6810367e6..920889ae18b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -163,6 +163,7 @@ class ServerArgs: enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 + init_expert_location: Optional[str] = None enable_eplb: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 @@ -421,8 +422,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -446,21 +447,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -473,13 +474,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. 
Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -514,9 +515,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -558,8 +559,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -579,7 +580,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1044,7 +1045,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1057,8 +1058,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1104,6 +1105,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling DeepEP MoE implementation for EP MoE.", ) + parser.add_argument( + "--init-expert-location", + type=str, + default=ServerArgs.init_expert_location, + help="Initial location of EP experts.", + ) parser.add_argument( "--ep-num-redundant-experts", type=int, @@ -1128,7 +1135,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1142,7 +1149,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 6b7b09c8ce1a02bfd99ae7f866c18562a162dc54 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:51:07 +0800 Subject: [PATCH 0410/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 1d2e1d6a728..d43f54fa5ac 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -616,4 +616,8 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata(server_args: ServerArgs, eplb_manager: EPLBManager): + if server_args.init_expert_location is not None: + return TODO + if server_args.enable_eplb: + return TODO(eplb_manager) return TODO From bf4d4b82e87695f089abbdce18320df7e03ae17f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:51:25 +0800 Subject: [PATCH 0411/1089] more --- python/sglang/srt/entrypoints/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index d43f54fa5ac..901e0d5f0c8 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -615,7 +615,8 @@ def _launch_subprocesses( return tokenizer_manager, scheduler_info -def _compute_initial_expert_location_metadata(server_args: ServerArgs, eplb_manager: EPLBManager): +def _compute_initial_expert_location_metadata(server_args: ServerArgs, + eplb_manager: EPLBManager) -> ExpertLocationMetadata: if server_args.init_expert_location is not None: return TODO if server_args.enable_eplb: From 00e2048dd8f34bc2b8aa7f3b805e93c0c5042fd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:51:47 +0800 Subject: [PATCH 0412/1089] more --- python/sglang/srt/managers/expert_location.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index bfdbcc257bc..4967eae8e21 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -20,7 +20,7 @@ class ExpertLocationMetadata: # -------------------------------- construction and mutation ------------------------------------ @staticmethod - def init_trivial(num_layers: int, num_logical_experts: int): + def init_trivial(): """Trivial location - logical expert i corresponds to physical expert i""" return ExpertLocationMetadata( num_layers=num_layers, @@ -40,6 +40,14 @@ def init_trivial(num_layers: int, num_logical_experts: int): ).repeat(num_layers, 1)[..., None], ) + @staticmethod + def init_by_mapping(): + return TODO + + @staticmethod + def init_by_eplb(): + return TODO + # -------------------------------- usage ------------------------------------ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): From 82596368ad1e1a24a87257c8bfa432c16c37f79c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:52:19 +0800 Subject: [PATCH 0413/1089] more --- python/sglang/srt/entrypoints/engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py 
b/python/sglang/srt/entrypoints/engine.py index 901e0d5f0c8..670c291f1ab 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -20,6 +20,7 @@ import asyncio import atexit import dataclasses +import json import logging import multiprocessing as mp import os @@ -617,8 +618,8 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata(server_args: ServerArgs, eplb_manager: EPLBManager) -> ExpertLocationMetadata: - if server_args.init_expert_location is not None: - return TODO + if (data := server_args.init_expert_location) is not None: + return ExpertLocationMetadata.init_by_mapping(**json.loads(data)) if server_args.enable_eplb: return TODO(eplb_manager) - return TODO + return ExpertLocationMetadata.init_trivial() From d930b74a2e1d4916ec77202aed871c15cd25a306 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:52:47 +0800 Subject: [PATCH 0414/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 670c291f1ab..178eeb76cfb 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -621,5 +621,5 @@ def _compute_initial_expert_location_metadata(server_args: ServerArgs, if (data := server_args.init_expert_location) is not None: return ExpertLocationMetadata.init_by_mapping(**json.loads(data)) if server_args.enable_eplb: - return TODO(eplb_manager) + return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial() From 67bdc44fc9d0b907e8093af1e29e5e9e851c3ee2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:55:15 +0800 Subject: [PATCH 0415/1089] more --- python/sglang/srt/managers/eplb_manager.py | 11 ----- python/sglang/srt/managers/expert_location.py | 45 ++++++++++++------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 6716bfae399..5ca67e56715 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -50,17 +50,6 @@ class _EPLBManagerNoop(EPLBManager): def _compute_expert_location_metadata_raw( server_args: ServerArgs, logical_count: torch.Tensor ): - model_config = ModelConfig.from_server_args(server_args) - model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config( - model_config - ) - - num_physical_experts = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts - # TODO consider case when DP attention is disabled and DP > 1 - world_size = server_args.tp_size - assert num_physical_experts % world_size == 0 - num_local_physical_experts = num_physical_experts // world_size - physical_to_logical_map, logical_to_all_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( weight=logical_count, diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 4967eae8e21..9de6e136698 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture +from sglang.srt.server_args import ServerArgs @dataclass @@ -20,33 +21,45 @@ class ExpertLocationMetadata: # 
-------------------------------- construction and mutation ------------------------------------ @staticmethod - def init_trivial(): + def init_trivial(server_args: ServerArgs): """Trivial location - logical expert i corresponds to physical expert i""" + common = ExpertLocationMetadata._init_common(server_args) + physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts + return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map) + + @staticmethod + def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): + if not isinstance(physical_to_logical_map, torch.Tensor): + physical_to_logical_map = torch.tensor(physical_to_logical_map) + return ExpertLocationMetadata( num_layers=num_layers, num_logical_experts=num_logical_experts, num_local_physical_experts=num_local_physical_experts, - physical_to_logical_map=torch.arange(0, num_physical_experts).repeat( - num_layers, 1 - ) - % num_logical_experts, - # Throw away the redundant experts here - highly inefficient, but we do not care since we will - # use EPLB distribution logic - logical_to_all_physical_map=torch.arange(0, num_logical_experts).repeat( - num_layers, 1 - )[..., None], - logical_to_rank_dispatch_physical_map=torch.arange( - 0, num_logical_experts - ).repeat(num_layers, 1)[..., None], + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=TODO, + logical_to_rank_dispatch_physical_map=TODO, ) @staticmethod - def init_by_mapping(): + def init_by_eplb(): return TODO @staticmethod - def init_by_eplb(): - return TODO + def _init_common(server_args: ServerArgs): + model_config = ModelConfig.from_server_args(server_args) + model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) + + num_physical_experts = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + # TODO consider case when DP attention is disabled and DP > 1 + world_size = server_args.tp_size + assert num_physical_experts % world_size == 0 + num_local_physical_experts = num_physical_experts // world_size + + return dict( + model_config_for_expert_location=model_config_for_expert_location, + num_local_physical_experts=num_local_physical_experts, + ) # -------------------------------- usage ------------------------------------ From aa415f4a8c5c7be09f696c887c230fc7ae569af2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:55:45 +0800 Subject: [PATCH 0416/1089] more --- python/sglang/srt/managers/expert_location.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9de6e136698..249ab853bbc 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -32,10 +32,12 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) + common = ExpertLocationMetadata._init_common(server_args) + model_config_for_expert_location = common["model_config_for_expert_location"] return ExpertLocationMetadata( - num_layers=num_layers, - num_logical_experts=num_logical_experts, - num_local_physical_experts=num_local_physical_experts, + num_layers=model_config_for_expert_location.num_layers, + num_logical_experts=model_config_for_expert_location.num_logical_experts, + 
num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=TODO, logical_to_rank_dispatch_physical_map=TODO, From e4fbf0968c109a7ea7fe5930585ada4f64da257a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:56:27 +0800 Subject: [PATCH 0417/1089] more --- python/sglang/srt/managers/expert_location.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 249ab853bbc..ef8c861868b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -24,7 +24,12 @@ class ExpertLocationMetadata: def init_trivial(server_args: ServerArgs): """Trivial location - logical expert i corresponds to physical expert i""" common = ExpertLocationMetadata._init_common(server_args) - physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts + num_physical_experts = common["num_physical_experts"] + model_config_for_expert_location = common["model_config_for_expert_location"] + + physical_to_logical_map = torch.arange(0, num_physical_experts).repeat( + model_config_for_expert_location.num_layers, 1) % model_config_for_expert_location.num_logical_experts + return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map) @staticmethod @@ -60,6 +65,7 @@ def _init_common(server_args: ServerArgs): return dict( model_config_for_expert_location=model_config_for_expert_location, + num_physical_experts=num_physical_experts, num_local_physical_experts=num_local_physical_experts, ) From bbd416f6890d2be355ee5b80976813c352e4dc06 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:57:02 +0800 Subject: [PATCH 0418/1089] more --- python/sglang/srt/managers/eplb_manager.py | 45 ------------------- python/sglang/srt/managers/expert_location.py | 43 +++++++++++++++++- 2 files changed, 42 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 5ca67e56715..19b295c53ff 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -46,48 +46,3 @@ def compute_expert_location_metadata(self): class _EPLBManagerNoop(EPLBManager): pass - -def _compute_expert_location_metadata_raw( - server_args: ServerArgs, logical_count: torch.Tensor -): - physical_to_logical_map, logical_to_all_physical_map, expert_count = ( - deepseek_eplb.rebalance_experts( - weight=logical_count, - num_replicas=num_physical_experts, - num_groups=model_config_for_expert_location.num_groups, - num_nodes=server_args.nnodes, - num_gpus=world_size, - ) - ) - return ExpertLocationMetadata( - num_layers=model_config_for_expert_location.num_layers, - num_local_physical_experts=num_local_physical_experts, - num_logical_experts=model_config_for_expert_location.num_logical_experts, - physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_all_physical_map, - logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map, num_gpus=world_size, - ), - ) - - -def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, -): - # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) - # This is rarely called, so we use for loops for maximum clarity - - r = random.Random() - - num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape - logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) - - for layer_id in range(num_layers): - for logical_expert_id in range(num_logical_experts): - for gpu_id in range(num_gpus): - candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id, logical_expert_id) - logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = r.choice(candidate_values) - - return logical_to_rank_dispatch_physical_map diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index ef8c861868b..9a0bd3f2839 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from typing import List, Optional @@ -50,7 +51,25 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): @staticmethod def init_by_eplb(): - return TODO + physical_to_logical_map, logical_to_all_physical_map, expert_count = ( + deepseek_eplb.rebalance_experts( + weight=logical_count, + num_replicas=num_physical_experts, + num_groups=model_config_for_expert_location.num_groups, + num_nodes=server_args.nnodes, + num_gpus=world_size, + ) + ) + return ExpertLocationMetadata( + num_layers=model_config_for_expert_location.num_layers, + num_local_physical_experts=num_local_physical_experts, + num_logical_experts=model_config_for_expert_location.num_logical_experts, + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map, num_gpus=world_size, + ), + ) @staticmethod def _init_common(server_args: ServerArgs): @@ -95,6 +114,28 @@ def logical_to_all_physical_raw( ] +def _compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, +): + # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) + # This is rarely called, so we use for loops for maximum clarity + + r = random.Random() + + num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) + + for layer_id in range(num_layers): + for logical_expert_id in range(num_logical_experts): + for gpu_id in range(num_gpus): + candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id) + logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = r.choice(candidate_values) + + return logical_to_rank_dispatch_physical_map + + @dataclass class ModelConfigForExpertLocation: num_layers: int From a3eff17706db149d2e1c21740b9e860aef2596e2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:58:11 +0800 Subject: [PATCH 0419/1089] more --- python/sglang/srt/managers/expert_location.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9a0bd3f2839..c53f07e0c79 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -4,6 +4,7 @@ import torch from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture from sglang.srt.server_args import ServerArgs @@ -50,24 +51,28 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): ) @staticmethod - def init_by_eplb(): + def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): + common = ExpertLocationMetadata._init_common(server_args) + model_config_for_expert_location = common["model_config_for_expert_location"] + physical_to_logical_map, logical_to_all_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( weight=logical_count, - num_replicas=num_physical_experts, + num_replicas=common["num_physical_experts"], num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, - num_gpus=world_size, + num_gpus=common["world_size"], ) ) + return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, - num_local_physical_experts=num_local_physical_experts, num_logical_experts=model_config_for_expert_location.num_logical_experts, + num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map, num_gpus=world_size, + logical_to_all_physical_map, num_gpus=common["world_size"], ), ) @@ -86,6 +91,7 @@ def _init_common(server_args: ServerArgs): model_config_for_expert_location=model_config_for_expert_location, num_physical_experts=num_physical_experts, num_local_physical_experts=num_local_physical_experts, + world_size=world_size, ) # -------------------------------- usage ------------------------------------ From de47475c79902927ae705a4b26f6316c0082ea5a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:58:18 +0800 Subject: [PATCH 0420/1089] more --- python/sglang/srt/managers/expert_location.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_location.py 
b/python/sglang/srt/managers/expert_location.py index c53f07e0c79..818331f0eb9 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -41,6 +41,7 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] + return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, num_logical_experts=model_config_for_expert_location.num_logical_experts, From 61c8afe4faebe389b73d08b7a3cb0087b2018323 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 20:58:59 +0800 Subject: [PATCH 0421/1089] more --- python/sglang/srt/managers/expert_location.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 818331f0eb9..9204aee8f6e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -41,14 +41,17 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] + logical_to_all_physical_map = TODO return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, num_logical_experts=model_config_for_expert_location.num_logical_experts, num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=TODO, - logical_to_rank_dispatch_physical_map=TODO, + logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map, num_gpus=common["world_size"], + ), ) @staticmethod From 51f7ac7dcf3cab1e2208785069867386655c2a64 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:00:30 +0800 Subject: [PATCH 0422/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9204aee8f6e..6e21e190458 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -31,17 +31,18 @@ def init_trivial(server_args: ServerArgs): physical_to_logical_map = torch.arange(0, num_physical_experts).repeat( model_config_for_expert_location.num_layers, 1) % model_config_for_expert_location.num_logical_experts + logical_to_all_physical_map = TODO - return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map) + return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map) @staticmethod - def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): + def init_by_mapping(server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] - logical_to_all_physical_map = TODO return ExpertLocationMetadata( 
num_layers=model_config_for_expert_location.num_layers, From c1c516108c8fe8808c9ecd793fee40915843ca28 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:01:30 +0800 Subject: [PATCH 0423/1089] more --- python/sglang/srt/managers/expert_location.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6e21e190458..dfe930ec064 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -28,10 +28,13 @@ def init_trivial(server_args: ServerArgs): common = ExpertLocationMetadata._init_common(server_args) num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] + num_layers = model_config_for_expert_location.num_layers + num_logical_experts = model_config_for_expert_location.num_logical_experts - physical_to_logical_map = torch.arange(0, num_physical_experts).repeat( - model_config_for_expert_location.num_layers, 1) % model_config_for_expert_location.num_logical_experts - logical_to_all_physical_map = TODO + physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts + # Throw away the redundant experts here - highly inefficient, but we do not care since we will + # use EPLB distribution logic + logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat(num_layers, 1)[..., None] return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map) From d4531d1c31ef75e63f1b4b1bad2528139cbd00e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:01:40 +0800 Subject: [PATCH 0424/1089] more --- python/sglang/srt/managers/expert_location.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index dfe930ec064..6095e3549c3 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -32,8 +32,7 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts - # Throw away the redundant experts here - highly inefficient, but we do not care since we will - # use EPLB distribution logic + # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat(num_layers, 1)[..., None] return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map, From a85ca16ca45ee4912cdb12acb023489e10627710 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:02:15 +0800 Subject: [PATCH 0425/1089] more --- python/sglang/srt/entrypoints/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 178eeb76cfb..673f73b0985 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -619,7 +619,8 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata(server_args: ServerArgs, eplb_manager: EPLBManager) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) 
is not None: - return ExpertLocationMetadata.init_by_mapping(**json.loads(data)) + # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used + return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) if server_args.enable_eplb: return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial() From e6009c6124d0eb28d0dbc9399d4e81e21dadfe0e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:03:10 +0800 Subject: [PATCH 0426/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/eplb_manager.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 673f73b0985..47b604e176e 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -623,4 +623,4 @@ def _compute_initial_expert_location_metadata(server_args: ServerArgs, return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) if server_args.enable_eplb: return eplb_manager.compute_expert_location_metadata() - return ExpertLocationMetadata.init_trivial() + return ExpertLocationMetadata.init_trivial(server_args) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 19b295c53ff..aee2a38a018 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,7 +1,5 @@ -import random from typing import TYPE_CHECKING, Optional -import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -27,7 +25,7 @@ def __init__(self): self.tokenizer_manager: Optional[TokenizerManager] = None def compute_expert_location_metadata(self) -> ExpertLocationMetadata: - return ExpertLocationMetadata.init_trivial(TODO) + raise NotImplementedError class _EPLBManagerReal(EPLBManager): @@ -39,10 +37,9 @@ def __init__(self, server_args: ServerArgs): def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() if logical_count is None: - return super().compute_expert_location_metadata() - return _compute_expert_location_metadata_raw(self._server_args, logical_count) + return ExpertLocationMetadata.init_trivial(self._server_args) + return ExpertLocationMetadata.init_by_eplb(self._server_args, logical_count=logical_count) class _EPLBManagerNoop(EPLBManager): pass - From 0643d12e605d476fc1feeb696a804fa26ee30497 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:03:24 +0800 Subject: [PATCH 0427/1089] more --- python/sglang/srt/managers/eplb_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index aee2a38a018..d1729e0c672 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -32,7 +32,8 @@ class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args - self._expert_distribution_storage = ExpertDistributionStorage() + TODO_init_later + self._expert_distribution_storage = ExpertDistributionStorage(TODO) def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() From 
a29b0b436bf1ccf0d1032ea873d6579711c9c988 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:04:11 +0800 Subject: [PATCH 0428/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index a1f80cdd454..a5358c71a67 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -2,8 +2,8 @@ class ExpertDistributionStorage: - def __init__(self, dir_data: Path): - TODO + def __init__(self, dir_data): + self._dir_data = Path(dir_data) def save_current(self): TODO From 1114ab3be33c7b08dfab7d7646f7a4003fe4133c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:05:38 +0800 Subject: [PATCH 0429/1089] more --- .../srt/managers/expert_distribution_storage.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index a5358c71a67..f6055e2868c 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -1,12 +1,19 @@ from pathlib import Path +from sglang.srt.managers.tokenizer_manager import TokenizerManager + class ExpertDistributionStorage: - def __init__(self, dir_data): + def __init__(self, dir_data, tokenizer_manager: TokenizerManager): self._dir_data = Path(dir_data) + self._tokenizer_manager = tokenizer_manager + + async def initialize(self): + await self._tokenizer_manager.start_expert_distribution_record() - def save_current(self): - TODO + async def save_current(self): + data = await self._tokenizer_manager.dump_expert_distribution_record() + TODO_write_data - def get_last_snapshot(self): - return TODO + async def get_last_snapshot(self): + return TODO_read_data From 8ccdb0832ce99897ddce8c0535dc9d5134bb3490 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:06:15 +0800 Subject: [PATCH 0430/1089] more --- python/sglang/srt/managers/eplb_manager.py | 7 ++++++- python/sglang/srt/managers/expert_distribution_storage.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index d1729e0c672..fdf6cd6b543 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -24,6 +24,9 @@ def init_new(server_args: ServerArgs): def __init__(self): self.tokenizer_manager: Optional[TokenizerManager] = None + async def initialize(self): + pass + def compute_expert_location_metadata(self) -> ExpertLocationMetadata: raise NotImplementedError @@ -32,8 +35,10 @@ class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args - TODO_init_later + + async def initialize(self): self._expert_distribution_storage = ExpertDistributionStorage(TODO) + await self._expert_distribution_storage.initialize() def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot() diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index f6055e2868c..dccb238f9c0 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ 
b/python/sglang/srt/managers/expert_distribution_storage.py @@ -15,5 +15,5 @@ async def save_current(self): data = await self._tokenizer_manager.dump_expert_distribution_record() TODO_write_data - async def get_last_snapshot(self): + def get_last_snapshot(self): return TODO_read_data From 4d21a244323209e204eca9bcc7d8a8026c4222ba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:07:14 +0800 Subject: [PATCH 0431/1089] more --- python/sglang/srt/managers/eplb_manager.py | 12 +++++------- python/sglang/srt/managers/tokenizer_manager.py | 7 +++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index fdf6cd6b543..92d1d73d3a1 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb @@ -21,10 +21,7 @@ def init_new(server_args: ServerArgs): else: return _EPLBManagerNoop() - def __init__(self): - self.tokenizer_manager: Optional[TokenizerManager] = None - - async def initialize(self): + async def initialize(self, tokenizer_manager: TokenizerManager): pass def compute_expert_location_metadata(self) -> ExpertLocationMetadata: @@ -36,8 +33,9 @@ def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args - async def initialize(self): - self._expert_distribution_storage = ExpertDistributionStorage(TODO) + async def initialize(self, tokenizer_manager: TokenizerManager): + self._expert_distribution_storage = ExpertDistributionStorage(dir_data=TODO, + tokenizer_manager=tokenizer_manager) await self._expert_distribution_storage.initialize() def compute_expert_location_metadata(self): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 95ee879bbf0..4d9d65709d9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -203,7 +202,7 @@ def __init__( revision=server_args.revision, ) - eplb_manager.tokenizer_manager = self + eplb_manager.initialize(self) self.eplb_manager = eplb_manager # Store states @@ -967,8 +966,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 47ee3c9878e54c25e09beb1005fb301df24fb18b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:09:09 +0800 Subject: [PATCH 0432/1089] more --- .../srt/managers/expert_distribution_storage.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index dccb238f9c0..b51bbe62cf6 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -1,4 +1,7 @@ +import json +import time from 
pathlib import Path +from typing import Any, Optional, Dict from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -13,7 +16,11 @@ async def initialize(self): async def save_current(self): data = await self._tokenizer_manager.dump_expert_distribution_record() - TODO_write_data + (self._dir_data / f"{time.time_ns()}.json").write_text(json.dumps(data)) - def get_last_snapshot(self): - return TODO_read_data + def get_last_snapshot(self) -> Optional[Dict[str, Any]]: + paths = sorted(list(self._dir_data.glob("*.json")), key=lambda p: int(p.stem)) + if len(paths) == 0: + return None + path = paths[-1] + return json.loads(path.read_text()) From 7059e3ec4551386a281195bcd1f483a1c2b96c42 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:09:51 +0800 Subject: [PATCH 0433/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index b51bbe62cf6..ce0f6cabc7e 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -1,10 +1,13 @@ import json +import logging import time from pathlib import Path from typing import Any, Optional, Dict from sglang.srt.managers.tokenizer_manager import TokenizerManager +logger = logging.getLogger(__name__) + class ExpertDistributionStorage: def __init__(self, dir_data, tokenizer_manager: TokenizerManager): @@ -16,11 +19,14 @@ async def initialize(self): async def save_current(self): data = await self._tokenizer_manager.dump_expert_distribution_record() - (self._dir_data / f"{time.time_ns()}.json").write_text(json.dumps(data)) + path = self._dir_data / f"{time.time_ns()}.json" + logger.info(f"save_current to path {path}") + path.write_text(json.dumps(data)) def get_last_snapshot(self) -> Optional[Dict[str, Any]]: paths = sorted(list(self._dir_data.glob("*.json")), key=lambda p: int(p.stem)) if len(paths) == 0: return None path = paths[-1] + logger.info(f"get_last_snapshot choose path {path}") return json.loads(path.read_text()) From be693e77983403e27a812cfe2d17afd0ff6e667a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:10:12 +0800 Subject: [PATCH 0434/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 92d1d73d3a1..ffeede2ceb8 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -39,7 +39,7 @@ async def initialize(self, tokenizer_manager: TokenizerManager): await self._expert_distribution_storage.initialize() def compute_expert_location_metadata(self): - logical_count = self._expert_distribution_storage.get_last_snapshot() + logical_count = self._expert_distribution_storage.get_last_snapshot()["logical_count"] if logical_count is None: return ExpertLocationMetadata.init_trivial(self._server_args) return ExpertLocationMetadata.init_by_eplb(self._server_args, logical_count=logical_count) From ae9bac8a8f4b865409c7290e264c9ad7b9868a53 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:10:42 +0800 Subject: [PATCH 0435/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py 
b/python/sglang/srt/managers/expert_distribution_storage.py
index ce0f6cabc7e..23da23b7a0a 100644
--- a/python/sglang/srt/managers/expert_distribution_storage.py
+++ b/python/sglang/srt/managers/expert_distribution_storage.py
@@ -13,6 +13,8 @@ class ExpertDistributionStorage:
     def __init__(self, dir_data, tokenizer_manager: TokenizerManager):
         self._dir_data = Path(dir_data)
         self._tokenizer_manager = tokenizer_manager
+        if not self._dir_data.exists():
+            self._dir_data.mkdir(parents=True, exist_ok=True)

     async def initialize(self):
         await self._tokenizer_manager.start_expert_distribution_record()

From dc1e050d796ab88b93856cb0b091547fda4d71a3 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 21:11:32 +0800
Subject: [PATCH 0436/1089] more

---
 python/sglang/srt/server_args.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 920889ae18b..e86e7892592 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -165,6 +165,7 @@ class ServerArgs:
     ep_num_redundant_experts: int = 0
     init_expert_location: Optional[str] = None
     enable_eplb: bool = False
+    eplb_storage_dir: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -1105,23 +1106,29 @@ def add_cli_args(parser: argparse.ArgumentParser):
         action="store_true",
         help="Enabling DeepEP MoE implementation for EP MoE.",
     )
-    parser.add_argument(
-        "--init-expert-location",
-        type=str,
-        default=ServerArgs.init_expert_location,
-        help="Initial location of EP experts.",
-    )
     parser.add_argument(
         "--ep-num-redundant-experts",
         type=int,
         default=ServerArgs.ep_num_redundant_experts,
         help="Allocate this number of redundant experts in expert parallel.",
     )
+    parser.add_argument(
+        "--init-expert-location",
+        type=str,
+        default=ServerArgs.init_expert_location,
+        help="Initial location of EP experts.",
+    )
     parser.add_argument(
         "--enable-eplb",
         action="store_true",
         help="Enable EPLB algorithm",
     )
+    parser.add_argument(
+        "--eplb-storage-dir",
+        type=str,
+        default=ServerArgs.eplb_storage_dir,
+        help="Storage directory of EPLB subsystem.",
+    )
     parser.add_argument(
         "--deepep-mode",
         type=str,

From f178507ef57496929e214628abe9fdbe49fdf4a6 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 21:12:09 +0800
Subject: [PATCH 0437/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 7 +++++--
 python/sglang/srt/server_args.py           | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index ffeede2ceb8..8d4be36fecf 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import TYPE_CHECKING

 from sglang.srt.configs.model_config import ModelConfig
@@ -34,8 +35,10 @@ def __init__(self, server_args: ServerArgs):
         self._server_args = server_args

     async def initialize(self, tokenizer_manager: TokenizerManager):
-        self._expert_distribution_storage = ExpertDistributionStorage(dir_data=TODO,
-                                                                      tokenizer_manager=tokenizer_manager)
+        self._expert_distribution_storage = ExpertDistributionStorage(
+            dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage",
+            tokenizer_manager=tokenizer_manager,
+        )
         await self._expert_distribution_storage.initialize()

     def compute_expert_location_metadata(self):
diff --git
a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e86e7892592..bc302468bd7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -165,7 +165,7 @@ class ServerArgs: ep_num_redundant_experts: int = 0 init_expert_location: Optional[str] = None enable_eplb: bool = False - eplb_storage_dir: Optional[str] = None + eplb_storage_dir: str = "/tmp/eplb_storage" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None From 4ebdffc1e7d094dafe2d5d61c04a45845d37e8e6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:13:23 +0800 Subject: [PATCH 0438/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/eplb_manager.py | 15 --------------- python/sglang/srt/managers/tokenizer_manager.py | 8 +++++--- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 47b604e176e..31498932558 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -499,7 +499,7 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - eplb_manager = EPLBManager.init_new(server_args) + eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None expert_location_metadata = _compute_initial_expert_location_metadata(server_args, eplb_manager) scheduler_procs = [] diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 8d4be36fecf..6b44a4aa396 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -15,21 +15,6 @@ class EPLBManager: - @staticmethod - def init_new(server_args: ServerArgs): - if server_args.enable_eplb: - return _EPLBManagerReal(server_args) - else: - return _EPLBManagerNoop() - - async def initialize(self, tokenizer_manager: TokenizerManager): - pass - - def compute_expert_location_metadata(self) -> ExpertLocationMetadata: - raise NotImplementedError - - -class _EPLBManagerReal(EPLBManager): def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 4d9d65709d9..23679f4f94e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -140,7 +140,7 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, expert_location_metadata: ExpertLocationMetadata, - eplb_manager: EPLBManager, + eplb_manager: Optional[EPLBManager], ): # Parse args self.server_args = server_args @@ -202,8 +202,10 @@ def __init__( revision=server_args.revision, ) - eplb_manager.initialize(self) - self.eplb_manager = eplb_manager + if eplb_manager is not None: + TODO_async + eplb_manager.initialize(self) + self.eplb_manager = eplb_manager # Store states self.no_create_loop = False From a0eeeba59e8f6e4d735e384503813d24454b19b3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:14:00 +0800 Subject: [PATCH 0439/1089] more --- python/sglang/srt/managers/eplb_manager.py | 5 ++++- python/sglang/srt/managers/tokenizer_manager.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 6b44a4aa396..f63d2d866e2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ 
b/python/sglang/srt/managers/eplb_manager.py @@ -19,12 +19,15 @@ def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args - async def initialize(self, tokenizer_manager: TokenizerManager): + def initialize(self, tokenizer_manager: TokenizerManager): self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage", tokenizer_manager=tokenizer_manager, ) + + async def handle_loop(self): await self._expert_distribution_storage.initialize() + TODO def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot()["logical_count"] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 23679f4f94e..d58055cdb45 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -203,7 +203,6 @@ def __init__( ) if eplb_manager is not None: - TODO_async eplb_manager.initialize(self) self.eplb_manager = eplb_manager From 44915c77ca69725ede5963d6cfd281b512dc3ca9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:14:31 +0800 Subject: [PATCH 0440/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index f63d2d866e2..affaa8c797c 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -18,16 +18,12 @@ class EPLBManager: def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args - - def initialize(self, tokenizer_manager: TokenizerManager): self._expert_distribution_storage = ExpertDistributionStorage( - dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage", - tokenizer_manager=tokenizer_manager, - ) + dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage") async def handle_loop(self): await self._expert_distribution_storage.initialize() - TODO + # TODO def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot()["logical_count"] From aae2976035d70042346f7f84934d17ce5a095c55 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:15:50 +0800 Subject: [PATCH 0441/1089] more --- python/sglang/srt/managers/eplb_manager.py | 3 +++ python/sglang/srt/managers/expert_distribution_storage.py | 6 ++++-- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index affaa8c797c..0f9ed93034a 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -20,6 +20,9 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage") + + def bind(self, tokenizer_manager: TokenizerManager): + self._expert_distribution_storage.bind(tokenizer_manager) async def handle_loop(self): await self._expert_distribution_storage.initialize() diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index 23da23b7a0a..8178e914faf 100644 --- 
a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -10,12 +10,14 @@ class ExpertDistributionStorage: - def __init__(self, dir_data, tokenizer_manager: TokenizerManager): + def __init__(self, dir_data): self._dir_data = Path(dir_data) - self._tokenizer_manager = tokenizer_manager if not self._dir_data.exists(): self._dir_data.mkdir(parents=True, exist_ok=True) + def bind(self, tokenizer_manager: TokenizerManager): + self._tokenizer_manager = tokenizer_manager + async def initialize(self): await self._tokenizer_manager.start_expert_distribution_record() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d58055cdb45..9aa7d5c2a4b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -203,7 +203,7 @@ def __init__( ) if eplb_manager is not None: - eplb_manager.initialize(self) + eplb_manager.bind(self) self.eplb_manager = eplb_manager # Store states From 686743a03947fa8d6011ab044500195d5c819416 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:16:07 +0800 Subject: [PATCH 0442/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 ++-- python/sglang/srt/managers/expert_distribution_storage.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 0f9ed93034a..85d28afdb49 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -20,12 +20,12 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage") - + def bind(self, tokenizer_manager: TokenizerManager): self._expert_distribution_storage.bind(tokenizer_manager) async def handle_loop(self): - await self._expert_distribution_storage.initialize() + await self._expert_distribution_storage.start() # TODO def compute_expert_location_metadata(self): diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index 8178e914faf..f544130f63a 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -18,7 +18,7 @@ def __init__(self, dir_data): def bind(self, tokenizer_manager: TokenizerManager): self._tokenizer_manager = tokenizer_manager - async def initialize(self): + async def start(self): await self._tokenizer_manager.start_expert_distribution_record() async def save_current(self): From 1cab4e422d4399f4e79b83825f5b2a09728444c0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:16:25 +0800 Subject: [PATCH 0443/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 85d28afdb49..47cf62553f2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -26,7 +26,7 @@ def bind(self, tokenizer_manager: TokenizerManager): async def handle_loop(self): await self._expert_distribution_storage.start() - # TODO + # TODO auto call rebalance, etc, when Engine supports that def compute_expert_location_metadata(self): logical_count = 
self._expert_distribution_storage.get_last_snapshot()["logical_count"] From 5ba2b94f0125b08a4259ef784dcbaef3ac6c85c7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:17:23 +0800 Subject: [PATCH 0444/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9aa7d5c2a4b..565d99a78fc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -202,9 +202,9 @@ def __init__( revision=server_args.revision, ) + self.eplb_manager = eplb_manager if eplb_manager is not None: eplb_manager.bind(self) - self.eplb_manager = eplb_manager # Store states self.no_create_loop = False @@ -894,6 +894,11 @@ def auto_create_handle_loop(self): loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) ) + if self.eplb_manager is not None: + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.eplb_manager.handle_loop)) + ) + async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(5) From 92e26a6997ee002535ac5d2301cd653edb3eb6e6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:18:36 +0800 Subject: [PATCH 0445/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 10 ++- python/sglang/srt/managers/eplb_manager.py | 12 +++- .../managers/expert_distribution_storage.py | 2 +- python/sglang/srt/managers/expert_location.py | 59 +++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 5 +- python/sglang/srt/server_args.py | 70 +++++++++---------- 6 files changed, 97 insertions(+), 61 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 31498932558..72689b2a5f3 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,6 +31,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -500,7 +501,9 @@ def _launch_subprocesses( ) eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None - expert_location_metadata = _compute_initial_expert_location_metadata(server_args, eplb_manager) + expert_location_metadata = _compute_initial_expert_location_metadata( + server_args, eplb_manager + ) scheduler_procs = [] if server_args.dp_size == 1: @@ -616,8 +619,9 @@ def _launch_subprocesses( return tokenizer_manager, scheduler_info -def _compute_initial_expert_location_metadata(server_args: ServerArgs, - eplb_manager: EPLBManager) -> ExpertLocationMetadata: +def _compute_initial_expert_location_metadata( + server_args: ServerArgs, eplb_manager: EPLBManager +) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 47cf62553f2..8490bdcaaf1 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -19,7 +19,9 @@ def __init__(self, server_args: ServerArgs): super().__init__() self._server_args = server_args 
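        # A minimal lifecycle sketch (an illustration inferred from the
        # surrounding patches, not a line of this diff itself): the engine
        # constructs the manager, the tokenizer manager binds itself, and the
        # tokenizer event loop drives the storage.
        #
        #     manager = EPLBManager(server_args)
        #     manager.bind(tokenizer_manager)
        #     loop.create_task(manager.handle_loop())
        #     metadata = manager.compute_expert_location_metadata()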
self._expert_distribution_storage = ExpertDistributionStorage( - dir_data=Path(self._server_args.eplb_storage_dir) / "expert_distribution_storage") + dir_data=Path(self._server_args.eplb_storage_dir) + / "expert_distribution_storage" + ) def bind(self, tokenizer_manager: TokenizerManager): self._expert_distribution_storage.bind(tokenizer_manager) @@ -29,10 +31,14 @@ async def handle_loop(self): # TODO auto call rebalance, etc, when Engine supports that def compute_expert_location_metadata(self): - logical_count = self._expert_distribution_storage.get_last_snapshot()["logical_count"] + logical_count = self._expert_distribution_storage.get_last_snapshot()[ + "logical_count" + ] if logical_count is None: return ExpertLocationMetadata.init_trivial(self._server_args) - return ExpertLocationMetadata.init_by_eplb(self._server_args, logical_count=logical_count) + return ExpertLocationMetadata.init_by_eplb( + self._server_args, logical_count=logical_count + ) class _EPLBManagerNoop(EPLBManager): diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index f544130f63a..9c9bffebf24 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -2,7 +2,7 @@ import logging import time from pathlib import Path -from typing import Any, Optional, Dict +from typing import Any, Dict, Optional from sglang.srt.managers.tokenizer_manager import TokenizerManager diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6095e3549c3..273fbe6230a 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -31,15 +32,25 @@ def init_trivial(server_args: ServerArgs): num_layers = model_config_for_expert_location.num_layers num_logical_experts = model_config_for_expert_location.num_logical_experts - physical_to_logical_map = torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts + physical_to_logical_map = ( + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts + ) # Highly inefficient, but we do not care since we will use EPLB distribution logic - logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat(num_layers, 1)[..., None] + logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( + num_layers, 1 + )[..., None] - return ExpertLocationMetadata.init_by_mapping(server_args, physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_all_physical_map) + return ExpertLocationMetadata.init_by_mapping( + server_args, + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, + ) @staticmethod - def init_by_mapping(server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map): + def init_by_mapping( + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -53,7 +64,8 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map, logical_to 
physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map, num_gpus=common["world_size"], + logical_to_all_physical_map, + num_gpus=common["world_size"], ), ) @@ -79,16 +91,22 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map, num_gpus=common["world_size"], + logical_to_all_physical_map, + num_gpus=common["world_size"], ), ) @staticmethod def _init_common(server_args: ServerArgs): model_config = ModelConfig.from_server_args(server_args) - model_config_for_expert_location = ModelConfigForExpertLocation.from_model_config(model_config) + model_config_for_expert_location = ( + ModelConfigForExpertLocation.from_model_config(model_config) + ) - num_physical_experts = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + num_physical_experts = ( + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts + ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size assert num_physical_experts % world_size == 0 @@ -110,13 +128,15 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: - return self.logical_to_all_physical_raw(self.logical_to_all_physical_map, layer_id, logical_expert_id) + return self.logical_to_all_physical_raw( + self.logical_to_all_physical_map, layer_id, logical_expert_id + ) @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -128,8 +148,8 @@ def logical_to_all_physical_raw( def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity @@ -137,14 +157,19 @@ def _compute_logical_to_rank_dispatch_physical_map( r = random.Random() num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape - logical_to_rank_dispatch_physical_map = torch.zeros((num_gpus, num_layers, num_logical_experts)) + logical_to_rank_dispatch_physical_map = torch.zeros( + (num_gpus, num_layers, num_logical_experts) + ) for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): for gpu_id in range(num_gpus): candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id, logical_expert_id) - logical_to_rank_dispatch_physical_map[gpu_id, layer_id, logical_expert_id] = r.choice(candidate_values) + logical_to_all_physical_map, layer_id, logical_expert_id + ) + logical_to_rank_dispatch_physical_map[ + gpu_id, layer_id, logical_expert_id + ] = r.choice(candidate_values) return logical_to_rank_dispatch_physical_map diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 565d99a78fc..038614017d2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -972,8 +973,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bc302468bd7..0978c56fce4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -423,8 +423,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -448,21 +448,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. 
" + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -475,13 +475,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -516,9 +516,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -560,8 +560,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -581,7 +581,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1046,7 +1046,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." 
- "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1059,8 +1059,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1142,7 +1142,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1156,7 +1156,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 92069e58f4c2dd4f0ca84f8450b61c2abc09df9d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:20:14 +0800 Subject: [PATCH 0446/1089] more --- python/sglang/srt/managers/eplb_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 8490bdcaaf1..ccbcaca6c19 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -23,7 +23,7 @@ def __init__(self, server_args: ServerArgs): / "expert_distribution_storage" ) - def bind(self, tokenizer_manager: TokenizerManager): + def bind(self, tokenizer_manager: "TokenizerManager"): self._expert_distribution_storage.bind(tokenizer_manager) async def handle_loop(self): @@ -39,7 +39,3 @@ def compute_expert_location_metadata(self): return ExpertLocationMetadata.init_by_eplb( self._server_args, logical_count=logical_count ) - - -class _EPLBManagerNoop(EPLBManager): - pass From 50cdc3dd17b66e00cea4bd71167fb05ffa1d8944 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:21:57 +0800 Subject: [PATCH 0447/1089] rm --- python/sglang/srt/entrypoints/engine.py | 12 +- python/sglang/srt/managers/deepseek_eplb.py | 220 ------------------ python/sglang/srt/managers/eplb_manager.py | 41 ---- .../managers/expert_distribution_storage.py | 36 --- python/sglang/srt/managers/expert_location.py | 28 --- .../sglang/srt/managers/tokenizer_manager.py | 11 - python/sglang/srt/server_args.py | 13 -- 7 files changed, 3 insertions(+), 358 deletions(-) delete mode 100644 python/sglang/srt/managers/deepseek_eplb.py delete mode 100644 python/sglang/srt/managers/eplb_manager.py delete mode 100644 python/sglang/srt/managers/expert_distribution_storage.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 72689b2a5f3..61137b70c63 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ 
b/python/sglang/srt/entrypoints/engine.py @@ -33,7 +33,6 @@ from PIL.Image import Image from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata # Fix a bug of Python threading @@ -500,10 +499,7 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None - expert_location_metadata = _compute_initial_expert_location_metadata( - server_args, eplb_manager - ) + expert_location_metadata = _compute_initial_expert_location_metadata(server_args) scheduler_procs = [] if server_args.dp_size == 1: @@ -584,7 +580,7 @@ def _launch_subprocesses( # Launch tokenizer process tokenizer_manager = TokenizerManager( - server_args, port_args, expert_location_metadata, eplb_manager + server_args, port_args, expert_location_metadata ) if server_args.chat_template: load_chat_template_for_openai_api( @@ -620,11 +616,9 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata( - server_args: ServerArgs, eplb_manager: EPLBManager + server_args: ServerArgs, ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) - if server_args.enable_eplb: - return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial(server_args) diff --git a/python/sglang/srt/managers/deepseek_eplb.py b/python/sglang/srt/managers/deepseek_eplb.py deleted file mode 100644 index e5875b21342..00000000000 --- a/python/sglang/srt/managers/deepseek_eplb.py +++ /dev/null @@ -1,220 +0,0 @@ -# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package - -from typing import Tuple - -import torch - - -def balanced_packing( - weight: torch.Tensor, num_packs: int -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs - are as balanced as possible. 
- - Parameters: - weight: [X, n], the weight of each item - num_packs: number of packs - - Returns: - pack_index: [X, n], the pack index of each item - rank_in_pack: [X, n], the rank of the item in the pack - """ - num_layers, num_groups = weight.shape - assert num_groups % num_packs == 0 - groups_per_pack = num_groups // num_packs - - if groups_per_pack == 1: - pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=weight.device - ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) - return pack_index, rank_in_pack - - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") - rank_in_pack = torch.full_like(pack_index, fill_value=-1) - for i in range(num_layers): - pack_weights = [0] * num_packs - pack_items = [0] * num_packs - for group in indices[i]: - pack = min( - (i for i in range(num_packs) if pack_items[i] < groups_per_pack), - key=pack_weights.__getitem__, - ) - assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] - pack_items[pack] += 1 - return pack_index, rank_in_pack - - -def replicate_experts( - weight: torch.Tensor, num_phy: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. - - Parameters: - weight: [X, num_log] - num_phy: total number of experts after replication - - Returns: - phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the replica rank - logcnt: [X, num_log], number of replicas for each logical expert - """ - n, num_log = weight.shape - num_redundant = num_phy - num_log - assert num_redundant >= 0 - device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) - logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) - arangen = torch.arange(n, dtype=torch.int64, device=device) - for i in range(num_log, num_phy): - redundant_indices = (weight / logcnt).max(dim=-1).indices - phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] - logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt - - -def rebalance_experts_hierarchical( - weight: torch.Tensor, - num_physical_experts: int, - num_groups: int, - num_nodes: int, - num_gpus: int, -): - """ - Parameters: - weight: [num_moe_layers, num_logical_experts] - num_physical_experts: number of physical experts after replication - num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster - num_gpus: number of GPUs, must be a multiple of `num_nodes` - - Returns: - physical_to_logical_map: [num_moe_layers, num_physical_experts] - logical_to_physical_map: [num_moe_layers, num_logical_experts, X] - logical_count: [num_moe_layers, num_logical_experts] - """ - num_layers, num_logical_experts = weight.shape - assert num_logical_experts % num_groups == 0 - group_size = num_logical_experts // num_groups - assert num_groups % num_nodes == 0 - groups_per_node = num_groups // num_nodes - assert num_gpus % num_nodes == 0 - assert num_physical_experts % num_gpus == 0 - phy_experts_per_gpu = num_physical_experts // num_gpus - - def inverse(perm: torch.Tensor) -> torch.Tensor: - inv = torch.empty_like(perm) - 
inv.scatter_( - 1, - perm, - torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( - perm.shape - ), - ) - return inv - - # Step 1: pack groups to nodes - tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) - log2mlog = ( - ( - (group_pack_index * groups_per_node + group_rank_in_pack) * group_size - ).unsqueeze(-1) - + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) - ).flatten(-2) - mlog2log = inverse(log2mlog) - - # Step 2: construct redundant experts within nodes - # [num_layers * num_nodes, num_logical_experts // num_nodes] - tokens_per_mlog = weight.gather(-1, mlog2log).view( - -1, num_logical_experts // num_nodes - ) - phy2mlog, phyrank, mlogcnt = replicate_experts( - tokens_per_mlog, num_physical_experts // num_nodes - ) - - # Step 3: pack physical_experts to GPUs - # [num_layers * num_nodes, num_physical_experts // num_nodes] - tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) - phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack - pphy2phy = inverse(phy2pphy) - - pphy2mlog = phy2mlog.gather( - -1, pphy2phy - ) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = ( - pphy2mlog.view(num_layers, num_nodes, -1) - + torch.arange( - 0, - num_logical_experts, - num_logical_experts // num_nodes, - device=group_pack_index.device, - ).view(1, -1, 1) - ).flatten(-2) - pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) - logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt - - -def rebalance_experts( - weight: torch.Tensor, - num_replicas: int, - num_groups: int, - num_nodes: int, - num_gpus: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Entry point for expert-parallelism load balancer. 
- - Parameters: - weight: [layers, num_logical_experts], the load statistics for all logical experts - num_replicas: number of physical experts, must be a multiple of `num_gpus` - num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster - num_gpus: number of GPUs, must be a multiple of `num_nodes` - - Returns: - physical_to_logical_map: [layers, num_replicas], the expert index of each replica - logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert - expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert - """ - num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() - if num_groups % num_nodes == 0: - # use hierarchical load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus - ) - else: - # use global load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus - ) - maxlogcnt = logcnt.max().item() - log2phy: torch.Tensor = torch.full( - (num_layers, num_logical_experts, maxlogcnt), - -1, - dtype=torch.int64, - device=logcnt.device, - ) - log2phy.view(num_layers, -1).scatter_( - -1, - phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( - num_layers, -1 - ), - ) - return phy2log, log2phy, logcnt - - -__all__ = ["rebalance_experts"] diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py deleted file mode 100644 index ccbcaca6c19..00000000000 --- a/python/sglang/srt/managers/eplb_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -from pathlib import Path -from typing import TYPE_CHECKING - -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.managers import deepseek_eplb -from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage -from sglang.srt.managers.expert_location import ( - ExpertLocationMetadata, - ModelConfigForExpertLocation, -) -from sglang.srt.server_args import ServerArgs - -if TYPE_CHECKING: - from sglang.srt.managers.tokenizer_manager import TokenizerManager - - -class EPLBManager: - def __init__(self, server_args: ServerArgs): - super().__init__() - self._server_args = server_args - self._expert_distribution_storage = ExpertDistributionStorage( - dir_data=Path(self._server_args.eplb_storage_dir) - / "expert_distribution_storage" - ) - - def bind(self, tokenizer_manager: "TokenizerManager"): - self._expert_distribution_storage.bind(tokenizer_manager) - - async def handle_loop(self): - await self._expert_distribution_storage.start() - # TODO auto call rebalance, etc, when Engine supports that - - def compute_expert_location_metadata(self): - logical_count = self._expert_distribution_storage.get_last_snapshot()[ - "logical_count" - ] - if logical_count is None: - return ExpertLocationMetadata.init_trivial(self._server_args) - return ExpertLocationMetadata.init_by_eplb( - self._server_args, logical_count=logical_count - ) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py deleted file mode 100644 index 9c9bffebf24..00000000000 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ /dev/null @@ -1,36 +0,0 @@ -import json -import logging -import time -from pathlib import Path -from typing import Any, Dict, Optional - -from 
sglang.srt.managers.tokenizer_manager import TokenizerManager - -logger = logging.getLogger(__name__) - - -class ExpertDistributionStorage: - def __init__(self, dir_data): - self._dir_data = Path(dir_data) - if not self._dir_data.exists(): - self._dir_data.mkdir(parents=True, exist_ok=True) - - def bind(self, tokenizer_manager: TokenizerManager): - self._tokenizer_manager = tokenizer_manager - - async def start(self): - await self._tokenizer_manager.start_expert_distribution_record() - - async def save_current(self): - data = await self._tokenizer_manager.dump_expert_distribution_record() - path = self._dir_data / f"{time.time_ns()}.json" - logger.info(f"save_current to path {path}") - path.write_text(json.dumps(data)) - - def get_last_snapshot(self) -> Optional[Dict[str, Any]]: - paths = sorted(list(self._dir_data.glob("*.json")), key=lambda p: int(p.stem)) - if len(paths) == 0: - return None - path = paths[-1] - logger.info(f"get_last_snapshot choose path {path}") - return json.loads(path.read_text()) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 273fbe6230a..01bc529db62 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,7 +5,6 @@ import torch from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture from sglang.srt.server_args import ServerArgs @@ -69,33 +68,6 @@ def init_by_mapping( ), ) - @staticmethod - def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): - common = ExpertLocationMetadata._init_common(server_args) - model_config_for_expert_location = common["model_config_for_expert_location"] - - physical_to_logical_map, logical_to_all_physical_map, expert_count = ( - deepseek_eplb.rebalance_experts( - weight=logical_count, - num_replicas=common["num_physical_experts"], - num_groups=model_config_for_expert_location.num_groups, - num_nodes=server_args.nnodes, - num_gpus=common["world_size"], - ) - ) - - return ExpertLocationMetadata( - num_layers=model_config_for_expert_location.num_layers, - num_logical_experts=model_config_for_expert_location.num_logical_experts, - num_local_physical_experts=common["num_local_physical_experts"], - physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_all_physical_map, - logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map, - num_gpus=common["world_size"], - ), - ) - @staticmethod def _init_common(server_args: ServerArgs): model_config = ModelConfig.from_server_args(server_args) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 038614017d2..93a9507bf19 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -52,7 +52,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers import expert_distribution -from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, @@ -141,7 +140,6 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, expert_location_metadata: 
ExpertLocationMetadata, - eplb_manager: Optional[EPLBManager], ): # Parse args self.server_args = server_args @@ -203,10 +201,6 @@ def __init__( revision=server_args.revision, ) - self.eplb_manager = eplb_manager - if eplb_manager is not None: - eplb_manager.bind(self) - # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} @@ -895,11 +889,6 @@ def auto_create_handle_loop(self): loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) ) - if self.eplb_manager is not None: - self.asyncio_tasks.add( - loop.create_task(print_exception_wrapper(self.eplb_manager.handle_loop)) - ) - async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(5) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0978c56fce4..18efb46dd1b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -164,8 +164,6 @@ class ServerArgs: deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 init_expert_location: Optional[str] = None - enable_eplb: bool = False - eplb_storage_dir: str = "/tmp/eplb_storage" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -1118,17 +1116,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.init_expert_location, help="Initial location of EP experts.", ) - parser.add_argument( - "--enable-eplb", - action="store_true", - help="Enable EPLB algorithm", - ) - parser.add_argument( - "--eplb-cache-dir", - type=str, - default=ServerArgs.eplb_storage_dir, - help="Storage directory of EPLB subsystem.", - ) parser.add_argument( "--deepep-mode", type=str, From b9de9d1d4eadb78c516faef92f467834eab82c47 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:23:59 +0800 Subject: [PATCH 0448/1089] Revert "rm" This reverts commit 50cdc3dd17b66e00cea4bd71167fb05ffa1d8944. 
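For reference, the load balancer restored by this revert can be exercised on its own. A minimal sketch with toy load statistics (shapes and values are illustrative, not from a real trace):

import torch

from sglang.srt.managers import deepseek_eplb

# Toy statistics: 2 MoE layers with 8 logical experts each; entries are
# observed token counts per logical expert.
weight = torch.randint(1, 100, (2, 8))

# Spread 16 physical replicas over 4 GPUs on 2 nodes, with 4 expert groups.
phy2log, log2phy, logcnt = deepseek_eplb.rebalance_experts(
    weight, num_replicas=16, num_groups=4, num_nodes=2, num_gpus=4
)
assert phy2log.shape == (2, 16)  # logical expert id of each physical replica
assert logcnt.shape == (2, 8)  # number of physical replicas per logical expert

Because the 4 groups divide evenly across the 2 nodes, this takes the hierarchical policy; the divisibility constraints documented in the restored file below are asserted inside, so invalid combinations fail fast.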
--- python/sglang/srt/entrypoints/engine.py | 12 +- python/sglang/srt/managers/deepseek_eplb.py | 220 ++++++++++++++++++ python/sglang/srt/managers/eplb_manager.py | 41 ++++ .../managers/expert_distribution_storage.py | 36 +++ python/sglang/srt/managers/expert_location.py | 28 +++ .../sglang/srt/managers/tokenizer_manager.py | 11 + python/sglang/srt/server_args.py | 13 ++ 7 files changed, 358 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/managers/deepseek_eplb.py create mode 100644 python/sglang/srt/managers/eplb_manager.py create mode 100644 python/sglang/srt/managers/expert_distribution_storage.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 61137b70c63..72689b2a5f3 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -33,6 +33,7 @@ from PIL.Image import Image from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata # Fix a bug of Python threading @@ -499,7 +500,10 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - expert_location_metadata = _compute_initial_expert_location_metadata(server_args) + eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None + expert_location_metadata = _compute_initial_expert_location_metadata( + server_args, eplb_manager + ) scheduler_procs = [] if server_args.dp_size == 1: @@ -580,7 +584,7 @@ def _launch_subprocesses( # Launch tokenizer process tokenizer_manager = TokenizerManager( - server_args, port_args, expert_location_metadata + server_args, port_args, expert_location_metadata, eplb_manager ) if server_args.chat_template: load_chat_template_for_openai_api( @@ -616,9 +620,11 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata( - server_args: ServerArgs, + server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) + if server_args.enable_eplb: + return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial(server_args) diff --git a/python/sglang/srt/managers/deepseek_eplb.py b/python/sglang/srt/managers/deepseek_eplb.py new file mode 100644 index 00000000000..e5875b21342 --- /dev/null +++ b/python/sglang/srt/managers/deepseek_eplb.py @@ -0,0 +1,220 @@ +# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package + +from typing import Tuple + +import torch + + +def balanced_packing( + weight: torch.Tensor, num_packs: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs + are as balanced as possible. 
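(Concretely, the implementation is a greedy longest-processing-time heuristic: items are visited in decreasing weight order, and each one is placed into the currently lightest pack that still holds fewer than n/m items.)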
+ + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=weight.device + ).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, num_phy: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + 
inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( + perm.shape + ), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) + ).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes + ) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes + ) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. 
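When num_groups is not divisible by num_nodes, group placement cannot respect node boundaries, so the function below falls back to a global policy by re-running the hierarchical routine with a single virtual group and node.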
+ + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py new file mode 100644 index 00000000000..ccbcaca6c19 --- /dev/null +++ b/python/sglang/srt/managers/eplb_manager.py @@ -0,0 +1,41 @@ +from pathlib import Path +from typing import TYPE_CHECKING + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers import deepseek_eplb +from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage +from sglang.srt.managers.expert_location import ( + ExpertLocationMetadata, + ModelConfigForExpertLocation, +) +from sglang.srt.server_args import ServerArgs + +if TYPE_CHECKING: + from sglang.srt.managers.tokenizer_manager import TokenizerManager + + +class EPLBManager: + def __init__(self, server_args: ServerArgs): + super().__init__() + self._server_args = server_args + self._expert_distribution_storage = ExpertDistributionStorage( + dir_data=Path(self._server_args.eplb_storage_dir) + / "expert_distribution_storage" + ) + + def bind(self, tokenizer_manager: "TokenizerManager"): + self._expert_distribution_storage.bind(tokenizer_manager) + + async def handle_loop(self): + await self._expert_distribution_storage.start() + # TODO auto call rebalance, etc, when Engine supports that + + def compute_expert_location_metadata(self): + logical_count = self._expert_distribution_storage.get_last_snapshot()[ + "logical_count" + ] + if logical_count is None: + return ExpertLocationMetadata.init_trivial(self._server_args) + return ExpertLocationMetadata.init_by_eplb( + self._server_args, logical_count=logical_count + ) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py new file mode 100644 index 00000000000..9c9bffebf24 --- /dev/null +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -0,0 +1,36 @@ +import json +import logging +import time +from pathlib import Path +from typing import Any, Dict, Optional + +from 
sglang.srt.managers.tokenizer_manager import TokenizerManager + +logger = logging.getLogger(__name__) + + +class ExpertDistributionStorage: + def __init__(self, dir_data): + self._dir_data = Path(dir_data) + if not self._dir_data.exists(): + self._dir_data.mkdir(parents=True, exist_ok=True) + + def bind(self, tokenizer_manager: TokenizerManager): + self._tokenizer_manager = tokenizer_manager + + async def start(self): + await self._tokenizer_manager.start_expert_distribution_record() + + async def save_current(self): + data = await self._tokenizer_manager.dump_expert_distribution_record() + path = self._dir_data / f"{time.time_ns()}.json" + logger.info(f"save_current to path {path}") + path.write_text(json.dumps(data)) + + def get_last_snapshot(self) -> Optional[Dict[str, Any]]: + paths = sorted(list(self._dir_data.glob("*.json")), key=lambda p: int(p.stem)) + if len(paths) == 0: + return None + path = paths[-1] + logger.info(f"get_last_snapshot choose path {path}") + return json.loads(path.read_text()) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 01bc529db62..273fbe6230a 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ import torch from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture from sglang.srt.server_args import ServerArgs @@ -68,6 +69,33 @@ def init_by_mapping( ), ) + @staticmethod + def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): + common = ExpertLocationMetadata._init_common(server_args) + model_config_for_expert_location = common["model_config_for_expert_location"] + + physical_to_logical_map, logical_to_all_physical_map, expert_count = ( + deepseek_eplb.rebalance_experts( + weight=logical_count, + num_replicas=common["num_physical_experts"], + num_groups=model_config_for_expert_location.num_groups, + num_nodes=server_args.nnodes, + num_gpus=common["world_size"], + ) + ) + + return ExpertLocationMetadata( + num_layers=model_config_for_expert_location.num_layers, + num_logical_experts=model_config_for_expert_location.num_logical_experts, + num_local_physical_experts=common["num_local_physical_experts"], + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map, + num_gpus=common["world_size"], + ), + ) + @staticmethod def _init_common(server_args: ServerArgs): model_config = ModelConfig.from_server_args(server_args) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 93a9507bf19..038614017d2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -52,6 +52,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers import expert_distribution +from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, @@ -140,6 +141,7 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, expert_location_metadata: 
ExpertLocationMetadata, + eplb_manager: Optional[EPLBManager], ): # Parse args self.server_args = server_args @@ -201,6 +203,10 @@ def __init__( revision=server_args.revision, ) + self.eplb_manager = eplb_manager + if eplb_manager is not None: + eplb_manager.bind(self) + # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} @@ -889,6 +895,11 @@ def auto_create_handle_loop(self): loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) ) + if self.eplb_manager is not None: + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.eplb_manager.handle_loop)) + ) + async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(5) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 18efb46dd1b..0978c56fce4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -164,6 +164,8 @@ class ServerArgs: deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 init_expert_location: Optional[str] = None + enable_eplb: bool = False + eplb_storage_dir: str = "/tmp/eplb_storage" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -1116,6 +1118,17 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.init_expert_location, help="Initial location of EP experts.", ) + parser.add_argument( + "--enable-eplb", + action="store_true", + help="Enable EPLB algorithm", + ) + parser.add_argument( + "--eplb-cache-dir", + type=str, + default=ServerArgs.eplb_storage_dir, + help="Storage directory of EPLB subsystem.", + ) parser.add_argument( "--deepep-mode", type=str, From e910b3e56a2a8133f83e28cbe0610f4fd8098b47 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:24:11 +0800 Subject: [PATCH 0449/1089] more --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0978c56fce4..0ee90dceee2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1124,7 +1124,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable EPLB algorithm", ) parser.add_argument( - "--eplb-cache-dir", + "--eplb-storage-dir", type=str, default=ServerArgs.eplb_storage_dir, help="Storage directory of EPLB subsystem.", From a13ad3a0fd51c9edfdb0b60fa9515a136fd6ca99 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:26:26 +0800 Subject: [PATCH 0450/1089] more --- python/sglang/srt/managers/io_struct.py | 2 ++ python/sglang/srt/managers/tp_worker.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 00affa0a4ed..5c472616282 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -675,6 +675,8 @@ class UpdateWeightFromDiskReqInput: model_path: str # The format to load the weights load_format: Optional[str] = None + # The parameter categories to filter + param_categories: Optional[List[str]] = None @dataclass diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 174f2e53321..498e52386e2 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -191,7 +191,7 @@ def forward_batch_embedding(self, 
model_worker_batch: ModelWorkerBatch): def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): success, message = self.model_runner.update_weights_from_disk( - recv_req.model_path, recv_req.load_format + recv_req.model_path, recv_req.load_format, recv_req.param_categories ) return success, message diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5663cafe1ab..afd9071f4ab 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -458,7 +458,7 @@ def load_model(self): ) from None def update_weights_from_disk( - self, model_path: str, load_format: str + self, model_path: str, load_format: str, param_categories: Optional[List[str]] ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" logger.info( From 3e151b998dc93ecf2a2927145e586432d5b04385 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:26:51 +0800 Subject: [PATCH 0451/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 33aab232fe3..dd136db848a 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -358,6 +358,7 @@ def update_weights_from_disk( self, model_path: str, load_format: Optional[str] = None, + param_categories: Optional[List[str]] = None, ): """Update the weights from disk inplace without re-launching the engine. @@ -368,6 +369,7 @@ def update_weights_from_disk( obj = UpdateWeightFromDiskReqInput( model_path=model_path, load_format=load_format, + param_categories=param_categories, ) loop = asyncio.get_event_loop() From 89a1a5e3353d3228332a55a34c7a9050f163eeba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:27:44 +0800 Subject: [PATCH 0452/1089] more --- python/sglang/srt/model_executor/model_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index afd9071f4ab..185ee338aea 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -488,13 +488,18 @@ def get_weight_iter(config): ) return iter + def filter_weight_iter(iter): + for name, weight in iter: + if TODO: + yield name, weight + def model_load_weights(model, iter): DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) return model with set_default_torch_dtype(self.model_config.dtype): try: - iter = get_weight_iter(self.model_config) + iter = filter_weight_iter(get_weight_iter(self.model_config)) except Exception as e: message = f"Failed to get weights iterator: {e}." 
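# A usage sketch for the param_categories plumbing introduced in this commit
# and the two before it (values illustrative; the filter predicate above is
# still a TODO and is completed in the follow-up commits):
#
#   engine.update_weights_from_disk(
#       model_path="/ckpts/step_1000",
#       param_categories=["moe"],  # e.g. refresh only the MoE expert weights
#   )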
return False, message From c351090845363bedb3636f39e4d33578c03a2593 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:28:07 +0800 Subject: [PATCH 0453/1089] more --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 185ee338aea..d5c1cb3163f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,7 +20,7 @@ import os import time from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Iterable import torch import torch.distributed as dist @@ -488,7 +488,7 @@ def get_weight_iter(config): ) return iter - def filter_weight_iter(iter): + def filter_weight_iter(iter: Iterable[Tuple[str, torch.Tensor]]): for name, weight in iter: if TODO: yield name, weight From 8690ff2dfdb8fefb0f485abca0101ebc1cc431e6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:28:54 +0800 Subject: [PATCH 0454/1089] more --- python/sglang/srt/model_executor/model_runner.py | 9 ++++++--- python/sglang/srt/models/deepseek_v2.py | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d5c1cb3163f..b63ec4a80c1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -489,9 +489,12 @@ def get_weight_iter(config): return iter def filter_weight_iter(iter: Iterable[Tuple[str, torch.Tensor]]): - for name, weight in iter: - if TODO: - yield name, weight + if param_categories is None: + yield from iter + else: + for name, weight in iter: + if self.model.get_: + yield name, weight def model_load_weights(model, iter): DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6a960a37150..5d6b43e9008 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1608,6 +1608,8 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() + def get_param_category(self, name): + return TODO class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass From b57dee3d495330e2983adc22ee25b2e8c0a06c25 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:29:30 +0800 Subject: [PATCH 0455/1089] more --- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 66 ++++++++++--------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b63ec4a80c1..0317f10aada 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -493,7 +493,7 @@ def filter_weight_iter(iter: Iterable[Tuple[str, torch.Tensor]]): yield from iter else: for name, weight in iter: - if self.model.get_: + if self.model.get_param_category(name) in param_categories: yield name, weight def model_load_weights(model, iter): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5d6b43e9008..f87e204fd31 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,10 +22,6 
@@ import torch import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -73,6 +69,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -412,7 +411,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -525,12 +524,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -542,8 +541,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -584,7 +583,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -771,16 +770,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -833,11 +832,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] 
= k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -913,15 +912,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -978,7 +977,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1496,7 +1495,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1527,11 +1526,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) @@ -1609,7 +1608,10 @@ def set_embed_and_head(self, embed, head): torch.cuda.synchronize() def get_param_category(self, name): - return TODO + if ".experts." 
in name: + return "moe" + return "others" + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass From d7e30c7a273a42f97b6cd530073b75f6a6b2e3e6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:30:03 +0800 Subject: [PATCH 0456/1089] fmt --- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 61 ++++++++++--------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0317f10aada..27868511574 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,7 +20,7 @@ import os import time from dataclasses import dataclass -from typing import List, Optional, Tuple, Union, Iterable +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f87e204fd31..89531ac720d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,6 +22,10 @@ import torch import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -69,9 +73,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -411,7 +412,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -524,12 +525,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -541,8 +542,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -583,7 +584,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -770,16 +771,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, 
self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -832,11 +833,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -912,15 +913,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
@@ -977,7 +978,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1495,7 +1496,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1526,11 +1527,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From 1b9d1648962cef2d7e20b450937019d38fd326f3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:33:51 +0800 Subject: [PATCH 0457/1089] more --- python/sglang/srt/managers/io_struct.py | 8 ++++++++ python/sglang/srt/managers/tokenizer_manager.py | 7 ++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 839d343fde0..8e4209a79be 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -28,6 +28,9 @@ else: Image = Any +if TYPE_CHECKING: + from sglang.srt.managers.expert_location import ExpertLocationMetadata + from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -669,6 +672,11 @@ class FlushCacheReq: pass +@dataclass +class UpdateExpertLocationMetadataReqInput: + expert_location_metadata: "ExpertLocationMetadata" + + @dataclass class UpdateWeightFromDiskReqInput: # The model path with the new weights diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 038614017d2..b1a803c1d6b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -666,6 +665,8 @@ async def dump_expert_distribution_record(self): expert_location_metadata=self.expert_location_metadata, ) + async def update_expert_location_metadata(self, obj:): + async def update_weights_from_disk( self, obj: UpdateWeightFromDiskReqInput, @@ -973,8 +974,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From c5c2790a0aa30de05b434c7c7c37587233f03927 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:34:11 +0800 Subject: [PATCH 0458/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index b1a803c1d6b..517e302738b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -90,7 +90,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, UpdateExpertLocationMetadataReqInput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -665,7 +665,8 @@ async def dump_expert_distribution_record(self): expert_location_metadata=self.expert_location_metadata, ) - async def update_expert_location_metadata(self, obj:): + async def update_expert_location_metadata(self, obj: UpdateExpertLocationMetadataReqInput): + TODO async def update_weights_from_disk( self, From e4416869374b546401236a7612468c6a7c2cd277 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:34:43 +0800 Subject: [PATCH 0459/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 336e445e0c7..8f2e66008de 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,7 +31,6 @@ import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -359,6 +358,9 @@ def update_weights_from_tensor( self.tokenizer_manager.update_weights_from_tensor(obj, None) ) + def update_expert_location_metadata(self, expert_location_metadata: ExpertLocationMetadata): + TODO + def update_weights_from_disk( self, model_path: str, From a938258803d39aee993722081126ead7490845c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:35:05 +0800 Subject: [PATCH 0460/1089] more --- python/sglang/srt/entrypoints/engine.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 8f2e66008de..6201751a220 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -57,7 +57,7 @@ RpcReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqInput, UpdateExpertLocationMetadataReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -359,8 +359,14 @@ def update_weights_from_tensor( ) def update_expert_location_metadata(self, expert_location_metadata: ExpertLocationMetadata): - TODO - + obj = UpdateExpertLocationMetadataReqInput( + expert_location_metadata=expert_location_metadata, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_expert_location_metadata(obj, None) + ) + def update_weights_from_disk( self, model_path: str, From 792da167785a6b29c9b7e6fa698ef1fe61e2df48 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:36:03 +0800 Subject: [PATCH 0461/1089] more --- python/sglang/srt/entrypoints/engine.py | 10 ++++------ python/sglang/srt/managers/io_struct.py | 2 +- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git 
a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6201751a220..26c9046b79c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -57,7 +57,7 @@ RpcReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, UpdateExpertLocationMetadataReqInput, + UpdateWeightsFromTensorReqInput, UpdateExpertLocationReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -358,14 +358,12 @@ def update_weights_from_tensor( self.tokenizer_manager.update_weights_from_tensor(obj, None) ) - def update_expert_location_metadata(self, expert_location_metadata: ExpertLocationMetadata): - obj = UpdateExpertLocationMetadataReqInput( + def update_expert_location(self, expert_location_metadata: ExpertLocationMetadata): + obj = UpdateExpertLocationReqInput( expert_location_metadata=expert_location_metadata, ) loop = asyncio.get_event_loop() - return loop.run_until_complete( - self.tokenizer_manager.update_expert_location_metadata(obj, None) - ) + return loop.run_until_complete(self.tokenizer_manager.update_expert_location(obj)) def update_weights_from_disk( self, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e4209a79be..a4cf214d6a5 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -673,7 +673,7 @@ class FlushCacheReq: @dataclass -class UpdateExpertLocationMetadataReqInput: +class UpdateExpertLocationReqInput: expert_location_metadata: "ExpertLocationMetadata" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 517e302738b..c5fb85d3533 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -90,7 +90,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, UpdateExpertLocationMetadataReqInput, + UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqInput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -665,7 +665,7 @@ async def dump_expert_distribution_record(self): expert_location_metadata=self.expert_location_metadata, ) - async def update_expert_location_metadata(self, obj: UpdateExpertLocationMetadataReqInput): + async def update_expert_location(self, obj: UpdateExpertLocationReqInput): TODO async def update_weights_from_disk( From ad69c3bdd8ea58061baeee55737bd880335e2aea Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:36:25 +0800 Subject: [PATCH 0462/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c5fb85d3533..0b80478a6d5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -666,6 +666,8 @@ async def dump_expert_distribution_record(self): ) async def update_expert_location(self, obj: UpdateExpertLocationReqInput): + self.auto_create_handle_loop() + TODO async def update_weights_from_disk( From 0ac176bcdaae62382fe5b71d5e7ea735b204ba14 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:37:51 +0800 Subject: [PATCH 0463/1089] fmt --- python/sglang/srt/entrypoints/engine.py 
| 8 ++++++-- python/sglang/srt/managers/tokenizer_manager.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 26c9046b79c..dbd11f382c9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,6 +31,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -55,9 +56,10 @@ ResumeMemoryOccupationReqInput, RpcReqInput, RpcReqOutput, + UpdateExpertLocationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, UpdateExpertLocationReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -363,7 +365,9 @@ def update_expert_location(self, expert_location_metadata: ExpertLocationMetadat expert_location_metadata=expert_location_metadata, ) loop = asyncio.get_event_loop() - return loop.run_until_complete(self.tokenizer_manager.update_expert_location(obj)) + return loop.run_until_complete( + self.tokenizer_manager.update_expert_location(obj) + ) def update_weights_from_disk( self, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 0b80478a6d5..b51ac603ca9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -85,12 +86,13 @@ SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateExpertLocationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -977,8 +979,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 6f3fd422288ed7e120be8e24654c1b8b24fa98e2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:42:24 +0800 Subject: [PATCH 0464/1089] more --- python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/server_args.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3ceee4a294a..1ecc3e1f405 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -391,6 +391,7 @@ def __init__( self.input_blocker = ( SchedulerInputBlocker(server_args, noop=self.attn_tp_rank != 0) if enable_colocated_batch_gen() + or server_args.enable_scheduler_input_blocker else None ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py 
index 28539dcee44..a7f5fa2c597 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -185,6 +185,7 @@ class ServerArgs: warmups: Optional[str] = None n_share_experts_fusion: int = 0 disable_shared_experts_fusion: bool = False + enable_scheduler_input_blocker: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -1117,6 +1118,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable shared experts fusion by setting n_share_experts_fusion to 0.", ) + parser.add_argument( + "--enable-scheduler-input-blocker", + action="store_true", + help="Enable input blocker for Scheduler.", + ) # Server warmups parser.add_argument( From c8261e8c8e6ee31645b2d64680d014aa153951c2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:43:39 +0800 Subject: [PATCH 0465/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 695636aea4a..9189f04cab1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -680,6 +680,8 @@ async def dump_expert_distribution_record(self): async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() + + assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" TODO From 8218a457560b6548544d73db2303b0f7c89806f1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:44:14 +0800 Subject: [PATCH 0466/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9189f04cab1..198994626a7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -680,10 +679,10 @@ async def dump_expert_distribution_record(self): async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() - assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" - + self._send_block_request(BlockReqType.BLOCK) TODO + self._send_block_request(BlockReqType.UNBLOCK) async def update_weights_from_disk( self, @@ -992,8 +991,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 3a5fe8e61dd5cc36ac26ad8d5b86639c7c5c5a4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:44:57 +0800 Subject: [PATCH 0467/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 198994626a7..8d52e916592 100644 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1218,6 +1218,10 @@ def __init__(self, sender, fan_out: int): self._ready_queue: Deque[asyncio.Future] = deque() async def __call__(self, obj): + await self.call_send(obj) + return await self.call_await() + + async def call_send(self, obj): ready_event = asyncio.Event() if self._result_event is not None or len(self._ready_queue) > 0: self._ready_queue.append(ready_event) @@ -1228,6 +1232,7 @@ async def __call__(self, obj): if obj: self._sender.send_pyobj(obj) + async def call_await(self): self._result_event = asyncio.Event() self._result_values = [] await self._result_event.wait() From a059d254d1e4d528a043c1b0a9709bd057cb121f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:45:51 +0800 Subject: [PATCH 0468/1089] more --- python/sglang/srt/managers/io_struct.py | 3 +++ python/sglang/srt/managers/tokenizer_manager.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9c2e64754ba..28c963cd498 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -676,6 +676,9 @@ class FlushCacheReq: class UpdateExpertLocationReqInput: expert_location_metadata: "ExpertLocationMetadata" +@dataclass +class UpdateExpertLocationReqOutput: + pass @dataclass class UpdateWeightFromDiskReqInput: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8d52e916592..8b1dc25b71b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -93,7 +93,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -272,6 +272,9 @@ def __init__( self.expert_distribution_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.update_expert_location_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self._result_dispatcher = TypeBasedDispatcher( [ @@ -325,6 +328,10 @@ def __init__( ExpertDistributionReqOutput, self.expert_distribution_communicator.handle_recv, ), + ( + UpdateExpertLocationReqOutput, + self.update_expert_location_communicator.handle_recv, + ), (HealthCheckOutput, lambda x: None), ] ) From da7c69deed9724349710f137519ffcf3faeccd6a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:46:13 +0800 Subject: [PATCH 0469/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8b1dc25b71b..1d0171dca4d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -687,9 +687,11 @@ async def dump_expert_distribution_record(self): async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" + self._send_block_request(BlockReqType.BLOCK) - TODO + await self.update_expert_location_communicator.call_send(obj) 
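 # A sketch of the intended ordering, inferred from this diff: the BLOCK
 # message and the update request travel through the same send_to_scheduler
 # queue, so the UNBLOCK issued on the next line reaches each scheduler only
 # after the update itself; call_await() then gathers the dp_size
 # UpdateExpertLocationReqOutput acknowledgements without keeping schedulers
 # blocked while the tokenizer waits.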
self._send_block_request(BlockReqType.UNBLOCK) + await self.update_expert_location_communicator.call_await() async def update_weights_from_disk( self, From a509f31bb06d777c6e617bc969b0c91e8a88e084 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:46:48 +0800 Subject: [PATCH 0470/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1d0171dca4d..e2144d4f437 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -688,10 +688,12 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" + self.expert_location_metadata = None self._send_block_request(BlockReqType.BLOCK) await self.update_expert_location_communicator.call_send(obj) self._send_block_request(BlockReqType.UNBLOCK) await self.update_expert_location_communicator.call_await() + self.expert_location_metadata = obj.expert_location_metadata async def update_weights_from_disk( self, From 23d12075072aa0da9a05222cce3435169a2cbbae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:46:56 +0800 Subject: [PATCH 0471/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e2144d4f437..e12c96fe5f4 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -689,10 +689,13 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" self.expert_location_metadata = None + self._send_block_request(BlockReqType.BLOCK) await self.update_expert_location_communicator.call_send(obj) self._send_block_request(BlockReqType.UNBLOCK) + await self.update_expert_location_communicator.call_await() + self.expert_location_metadata = obj.expert_location_metadata async def update_weights_from_disk( From 5e5630973e313182d762540c9cc7d2dfb48df41d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:47:43 +0800 Subject: [PATCH 0472/1089] more --- python/sglang/srt/managers/scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 0e9a275dd13..a9c7cc0e984 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -88,7 +88,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqInput, UpdateExpertLocationReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -414,6 +414,7 @@ def __init__( (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), + (UpdateExpertLocationReqInput, self.update_expert_location), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), ( @@ -1794,6 +1795,10 @@ def abort_request(self, recv_req: 
AbortReq): def _pause_engine(self) -> Tuple[List[Req], int]: raise NotImplementedError() + def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): + TODO + return UpdateExpertLocationReqOutput() + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): """In-place update of the weights from disk.""" success, message = self.tp_worker.update_weights_from_disk(recv_req) From 7a4517b4596b856a74186ce392dbeb7c7de039c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:48:27 +0800 Subject: [PATCH 0473/1089] more --- python/sglang/srt/managers/scheduler.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a9c7cc0e984..584a9581131 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -134,6 +132,7 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -374,8 +373,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -392,7 +391,7 @@ def __init__( self.input_blocker = ( SchedulerInputBlocker(server_args, noop=self.attn_tp_rank != 0) if enable_colocated_batch_gen() - or server_args.enable_scheduler_input_blocker + or server_args.enable_scheduler_input_blocker else None ) @@ -1265,10 +1264,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1293,9 +1292,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1527,8 +1526,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. 
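 # (e.g. extend_len=12 with logprob_start_len=10 yields max(12 - 10, 1) = 2
 # tokens, while a start past the extend range, say logprob_start_len=15,
 # is clamped to 1.)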
max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) @@ -1795,8 +1794,8 @@ def abort_request(self, recv_req: AbortReq): def _pause_engine(self) -> Tuple[List[Req], int]: raise NotImplementedError() - def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): - TODO + def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): + self.tp_worker.worker.model_runner.update_expert_location(recv_req) return UpdateExpertLocationReqOutput() def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): From e6925deae5eafea6c15de8cf8da7c817a380bdf1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:48:48 +0800 Subject: [PATCH 0474/1089] more --- python/sglang/srt/model_executor/model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index dcc6d96b606..be63ad47f83 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,6 +47,7 @@ from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import ( get_global_expert_location_metadata, global_server_args_dict, @@ -475,6 +476,9 @@ def load_model(self): f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." ) from None + def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): + TODO + def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] ) -> tuple[bool, str]: From d87235f3a5a3a2dc40ef4b3a937e845c1b82a7f2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:49:31 +0800 Subject: [PATCH 0475/1089] more --- python/sglang/srt/model_executor/model_runner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index be63ad47f83..629f9eaf043 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -181,7 +181,7 @@ def __init__( set_global_expert_location_metadata(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -477,8 +477,14 @@ def load_model(self): ) from None def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): + logger.info("update_expert_location start") + torch.distributed.barrier() + TODO + torch.distributed.barrier() + logger.info("update_expert_location end") + def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] ) -> tuple[bool, str]: @@ -934,7 +940,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From de7dbd221236e47cf210af8aa0aa99256501a3c7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:50:05 +0800 Subject: [PATCH 0476/1089] more --- python/sglang/srt/model_executor/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 629f9eaf043..b62c0f9a96d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -480,6 +480,7 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location start") torch.distributed.barrier() + get_global_expert_location_metadata().update(recv_req.expert_location_metadata) TODO torch.distributed.barrier() From 71bee874b4b2098144f63213aa04de10d19523aa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:50:46 +0800 Subject: [PATCH 0477/1089] cherry pick --- python/sglang/srt/managers/expert_location.py | 16 ++++++++++++++++ python/sglang/srt/model_executor/model_runner.py | 1 - 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 273fbe6230a..6b8e9cc1249 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -119,6 +119,22 @@ def _init_common(server_args: ServerArgs): world_size=world_size, ) + def update(self, other: "ExpertLocationMetadata"): + for field in [ + "num_layers", + "num_local_physical_experts", + "num_logical_experts", + ]: + assert getattr(self, field) == getattr(other, field) + + for field in [ + "physical_to_logical_map", + "logical_to_all_physical_map", + "logical_to_rank_dispatch_physical_map", + ]: + # Cannot update address to avoid breaking CUDA graph + getattr(self, field)[...] 
= getattr(other, field) + # -------------------------------- usage ------------------------------------ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b62c0f9a96d..862cc84543e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig From 15124de0d85ee3cbc721a0900dbd800211fdde76 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:51:09 +0800 Subject: [PATCH 0478/1089] more --- python/sglang/srt/managers/expert_location.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 6b8e9cc1249..3c82a5c5ab0 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -3,7 +3,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -33,8 +32,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -49,7 +48,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -104,8 +103,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -129,12 +128,16 @@ def update(self, other: "ExpertLocationMetadata"): for field in [ "physical_to_logical_map", - "logical_to_all_physical_map", "logical_to_rank_dispatch_physical_map", ]: # Cannot update address to avoid breaking CUDA graph getattr(self, field)[...] 
= getattr(other, field) + for field in [ + "logical_to_all_physical_map", + ]: + setattr(self, field, getattr(other, field)) + # -------------------------------- usage ------------------------------------ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): @@ -144,7 +147,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -152,7 +155,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -164,8 +167,8 @@ def logical_to_all_physical_raw( def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity From c362860e4909d6c460bdf1347e7ce1118b16951b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:51:57 +0800 Subject: [PATCH 0479/1089] more --- python/sglang/srt/model_executor/model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 862cc84543e..b0bfbf8fcce 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -480,7 +480,11 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): torch.distributed.barrier() get_global_expert_location_metadata().update(recv_req.expert_location_metadata) - TODO + self.update_weights_from_disk( + model_path=TODO, + load_format=TODO, + param_categories=["moe"], + ) torch.distributed.barrier() logger.info("update_expert_location end") From d242b7922d31fef2d885bfdeebb13fd97173c950 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:52:15 +0800 Subject: [PATCH 0480/1089] more --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b0bfbf8fcce..9d7cfab63a1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -481,8 +481,8 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): get_global_expert_location_metadata().update(recv_req.expert_location_metadata) self.update_weights_from_disk( - model_path=TODO, - load_format=TODO, + model_path=self.model_config.model_path, + load_format=self.server_args.load_format, param_categories=["moe"], ) From 6e5d979c77b2a43b89edf424e3f43bfdcd1b1561 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:52:41 +0800 Subject: [PATCH 0481/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9d7cfab63a1..776562b16e2 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -486,8 +486,8 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput):
 param_categories=["moe"],
 )

- torch.distributed.barrier()
 logger.info("update_expert_location end")
+ torch.distributed.barrier()

 def update_weights_from_disk(
 self, model_path: str, load_format: str, param_categories: Optional[List[str]]

From 44f308df8cebe563d2eff4f8619ef59a619dfb67 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 21:53:45 +0800
Subject: [PATCH 0482/1089] more

---
 python/sglang/srt/model_executor/model_runner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 776562b16e2..9493e2cb947 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -480,6 +480,8 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput):
 torch.distributed.barrier()

 get_global_expert_location_metadata().update(recv_req.expert_location_metadata)
+
+ # We may be able to further reduce lock time by faster copying, pre-transferring, etc
 self.update_weights_from_disk(

From b98d8cca29d89915d10ece60a2ccb8701b8d960a Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 21:55:40 +0800
Subject: [PATCH 0483/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index ccbcaca6c19..58d2ea8391a 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -1,3 +1,4 @@
+import asyncio
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -20,7 +21,7 @@ def __init__(self, server_args: ServerArgs):
 self._server_args = server_args
 self._expert_distribution_storage = ExpertDistributionStorage(
 dir_data=Path(self._server_args.eplb_storage_dir)
- / "expert_distribution_storage"
+ / "expert_distribution_storage"
 )

 def bind(self, tokenizer_manager: "TokenizerManager"):
@@ -28,7 +29,12 @@ def bind(self, tokenizer_manager: "TokenizerManager"):

 async def handle_loop(self):
 await self._expert_distribution_storage.start()
- # TODO auto call rebalance, etc, when Engine supports that
+ while True:
+ await asyncio.sleep(TODO)
+ self.rebalance()
+
+ def rebalance(self):
+ TODO

 def compute_expert_location_metadata(self):
 logical_count = self._expert_distribution_storage.get_last_snapshot()[

From 5f3d7eaf1636d6e903e6d5bf5784bbc280e11afa Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 10 Apr 2025 21:56:58 +0800
Subject: [PATCH 0484/1089] more

---
 python/sglang/srt/managers/eplb_manager.py | 2 +-
 python/sglang/srt/server_args.py | 77 ++++++++++++----------
 2 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py
index 58d2ea8391a..b05ccf56e24 100644
--- a/python/sglang/srt/managers/eplb_manager.py
+++ b/python/sglang/srt/managers/eplb_manager.py
@@ -30,7 +30,7 @@ def bind(self, tokenizer_manager: "TokenizerManager"):
 async def handle_loop(self):
 await self._expert_distribution_storage.start()
 while True:
- await asyncio.sleep(TODO)
+ await asyncio.sleep(self._server_args.eplb_update_period)
 self.rebalance()

diff --git
a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7f3ef192cb0..24e823e273d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -166,6 +166,7 @@ class ServerArgs: init_expert_location: Optional[str] = None enable_eplb: bool = False eplb_storage_dir: str = "/tmp/eplb_storage" + eplb_rebalance_period: Optional[int] = None enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -424,8 +425,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -449,21 +450,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -476,13 +477,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. 
Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -517,9 +518,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -561,8 +562,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -582,7 +583,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1047,7 +1048,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1060,8 +1061,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. 
" + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1130,6 +1131,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.eplb_storage_dir, help="Storage directory of EPLB subsystem.", ) + parser.add_argument( + "--eplb-rebalance-period", + type=int, + default=ServerArgs.eplb_rebalance_period, + help="Time (inm seconds) to automatically trigger a EPLB re-balance.", + ) parser.add_argument( "--deepep-mode", type=str, @@ -1143,7 +1150,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1162,7 +1169,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From b6ac08012e96ee3b138719bd2c927a6ca6e3434a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:57:09 +0800 Subject: [PATCH 0485/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b05ccf56e24..30e306a2e4e 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -30,7 +30,7 @@ def bind(self, tokenizer_manager: "TokenizerManager"): async def handle_loop(self): await self._expert_distribution_storage.start() while True: - await asyncio.sleep(self._server_args.eplb_update_period) + await asyncio.sleep(self._server_args.eplb_rebalance_period or 100000000) self.rebalance() def rebalance(self): From 0a0e1d0cfaf7d423a01880c22eef6f5cd1df6d86 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 21:57:39 +0800 Subject: [PATCH 0486/1089] more --- python/sglang/srt/managers/eplb_manager.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 30e306a2e4e..87136036e9b 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,4 +1,5 @@ import asyncio +import logging from pathlib import Path from typing import TYPE_CHECKING @@ -14,6 +15,8 @@ if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager +logger = logging.getLogger(__name__) + class EPLBManager: def __init__(self, server_args: ServerArgs): @@ -31,10 +34,12 @@ async def handle_loop(self): await self._expert_distribution_storage.start() while True: await asyncio.sleep(self._server_args.eplb_rebalance_period or 100000000) - self.rebalance() + await self.rebalance() - def rebalance(self): + async def rebalance(self): + logger.info("rebalance start") TODO + logger.info("rebalance end") def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot()[ From 2e3def943feaf30df8315228b58559b13ad77a75 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 
21:58:39 +0800 Subject: [PATCH 0487/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 87136036e9b..0a691a89986 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -34,6 +34,7 @@ async def handle_loop(self): await self._expert_distribution_storage.start() while True: await asyncio.sleep(self._server_args.eplb_rebalance_period or 100000000) + self.save_expert_distribution() await self.rebalance() async def rebalance(self): @@ -41,6 +42,9 @@ async def rebalance(self): TODO logger.info("rebalance end") + def save_expert_distribution(self): + self._expert_distribution_storage.save_current() + def compute_expert_location_metadata(self): logical_count = self._expert_distribution_storage.get_last_snapshot()[ "logical_count" From e72d4b18db1a969426f738bd241752a3029d0ead Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:00:22 +0800 Subject: [PATCH 0488/1089] more --- python/sglang/srt/managers/eplb_manager.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 0a691a89986..efd11c40de2 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -10,6 +10,7 @@ ExpertLocationMetadata, ModelConfigForExpertLocation, ) +from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -28,6 +29,7 @@ def __init__(self, server_args: ServerArgs): ) def bind(self, tokenizer_manager: "TokenizerManager"): + self._tokenizer_manager = tokenizer_manager self._expert_distribution_storage.bind(tokenizer_manager) async def handle_loop(self): @@ -38,9 +40,9 @@ async def handle_loop(self): await self.rebalance() async def rebalance(self): - logger.info("rebalance start") - TODO - logger.info("rebalance end") + expert_location_metadata = self.compute_expert_location_metadata() + await self._tokenizer_manager.update_expert_location( + UpdateExpertLocationReqInput(expert_location_metadata=expert_location_metadata)) def save_expert_distribution(self): self._expert_distribution_storage.save_current() From bd3cc0cd25a879065e50a0e6b43c4287c950c11d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:02:40 +0800 Subject: [PATCH 0489/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index efd11c40de2..4eebf7c8083 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -35,7 +35,9 @@ def bind(self, tokenizer_manager: "TokenizerManager"): async def handle_loop(self): await self._expert_distribution_storage.start() while True: - await asyncio.sleep(self._server_args.eplb_rebalance_period or 100000000) + sleep_time = self._server_args.eplb_rebalance_period or 100000000 + logger.info(f"Sleep {sleep_time} seconds before automatically trigger rebalancing") + await asyncio.sleep(sleep_time) self.save_expert_distribution() await self.rebalance() From 9196eb55ec7907627dde2a2218a781be843caceb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:03:07 +0800 Subject: [PATCH 0490/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 4eebf7c8083..b52fcf9921b 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -35,7 +35,7 @@ def bind(self, tokenizer_manager: "TokenizerManager"): async def handle_loop(self): await self._expert_distribution_storage.start() while True: - sleep_time = self._server_args.eplb_rebalance_period or 100000000 + sleep_time = self._server_args.eplb_rebalance_period or 1000000000 logger.info(f"Sleep {sleep_time} seconds before automatically trigger rebalancing") await asyncio.sleep(sleep_time) self.save_expert_distribution() From d1563bdfd47c8fc3340701b6f5e8aaa8a1b2094e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:04:09 +0800 Subject: [PATCH 0491/1089] fmt --- python/sglang/srt/managers/eplb_manager.py | 11 ++- python/sglang/srt/managers/expert_location.py | 19 ++--- python/sglang/srt/managers/io_struct.py | 2 + python/sglang/srt/managers/scheduler.py | 31 ++++---- .../sglang/srt/managers/tokenizer_manager.py | 12 ++-- .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/server_args.py | 70 +++++++++---------- 7 files changed, 83 insertions(+), 67 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b52fcf9921b..8aa7deefec0 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -25,7 +25,7 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) - / "expert_distribution_storage" + / "expert_distribution_storage" ) def bind(self, tokenizer_manager: "TokenizerManager"): @@ -36,7 +36,9 @@ async def handle_loop(self): await self._expert_distribution_storage.start() while True: sleep_time = self._server_args.eplb_rebalance_period or 1000000000 - logger.info(f"Sleep {sleep_time} seconds before automatically trigger rebalancing") + logger.info( + f"Sleep {sleep_time} seconds before automatically trigger rebalancing" + ) await asyncio.sleep(sleep_time) self.save_expert_distribution() await self.rebalance() @@ -44,7 +46,10 @@ async def handle_loop(self): async def rebalance(self): expert_location_metadata = self.compute_expert_location_metadata() await self._tokenizer_manager.update_expert_location( - UpdateExpertLocationReqInput(expert_location_metadata=expert_location_metadata)) + UpdateExpertLocationReqInput( + expert_location_metadata=expert_location_metadata + ) + ) def save_expert_distribution(self): self._expert_distribution_storage.save_current() diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3c82a5c5ab0..48bef6a39ab 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -32,8 +33,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, 
num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -48,7 +49,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -103,8 +104,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -147,7 +148,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -155,7 +156,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -167,8 +168,8 @@ def logical_to_all_physical_raw( def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 28c963cd498..6516a86cc7e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -676,10 +676,12 @@ class FlushCacheReq: class UpdateExpertLocationReqInput: expert_location_metadata: "ExpertLocationMetadata" + @dataclass class UpdateExpertLocationReqOutput: pass + @dataclass class UpdateWeightFromDiskReqInput: # The model path with the new weights diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 584a9581131..7cc90196adf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,6 +32,8 @@ import setproctitle import torch import zmq +from torch.distributed import barrier + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -81,12 +83,14 @@ SetInternalStateReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateExpertLocationReqInput, + UpdateExpertLocationReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqInput, UpdateExpertLocationReqOutput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -132,7 +136,6 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -from torch.distributed import barrier logger = logging.getLogger(__name__) @@ -373,8 +376,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -391,7 +394,7 @@ def __init__( self.input_blocker = ( SchedulerInputBlocker(server_args, noop=self.attn_tp_rank != 0) if enable_colocated_batch_gen() - or server_args.enable_scheduler_input_blocker + or server_args.enable_scheduler_input_blocker else None ) @@ -1264,10 +1267,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1292,9 +1295,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1526,8 +1529,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. 
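The UpdateExpertLocationReqInput / UpdateExpertLocationReqOutput pair added to io_struct.py above follows the request/response convention these patches use throughout: every *ReqInput dataclass sent by the tokenizer manager is routed by the scheduler's TypeBasedDispatcher to a single handler, which answers with the matching *ReqOutput. A minimal sketch of that dispatch idea, using toy types rather than the real sglang classes:

from dataclasses import dataclass


@dataclass
class PingReqInput:  # toy request type, for illustration only
    payload: str


@dataclass
class PingReqOutput:
    success: bool


class ToyDispatcher:
    # Assumed to mirror the role of sglang.utils.TypeBasedDispatcher:
    # look up the handler registered for the exact type of the request.
    def __init__(self, mapping):
        self._mapping = dict(mapping)

    def __call__(self, obj):
        return self._mapping[type(obj)](obj)


dispatcher = ToyDispatcher([(PingReqInput, lambda req: PingReqOutput(success=True))])
assert dispatcher(PingReqInput(payload="hi")).success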
max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e12c96fe5f4..200b9bfd65c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -88,12 +89,13 @@ TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateExpertLocationReqInput, + UpdateExpertLocationReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, UpdateExpertLocationReqOutput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -686,7 +688,9 @@ async def dump_expert_distribution_record(self): async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() - assert self.server_args.enable_scheduler_input_blocker, f"update_expert_location requires --enable-scheduler-input-blocker" + assert ( + self.server_args.enable_scheduler_input_blocker + ), f"update_expert_location requires --enable-scheduler-input-blocker" self.expert_location_metadata = None @@ -1005,8 +1009,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9493e2cb947..41e862e8228 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -180,7 +181,7 @@ def __init__( set_global_expert_location_metadata(expert_location_metadata) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -946,7 +947,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 24e823e273d..f7bf4981e86 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -425,8 +425,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -450,21 +450,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -477,13 +477,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. 
Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -518,9 +518,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -562,8 +562,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -583,7 +583,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1048,7 +1048,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1061,8 +1061,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1150,7 +1150,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1169,7 +1169,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. 
--warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 4635582bdd07f82e13f69a4c65f329fbfc05a7c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:05:29 +0800 Subject: [PATCH 0492/1089] more --- test/srt/test_eplb.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100755 test/srt/test_eplb.py diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py new file mode 100755 index 00000000000..2295eee78c9 --- /dev/null +++ b/test/srt/test_eplb.py @@ -0,0 +1,21 @@ +import os +import unittest + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestEPLB(CustomTestCase): + TODO + +if __name__ == "__main__": + unittest.main() From 0e17b0292e298a641a23101568b841c64f0fb927 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:05:49 +0800 Subject: [PATCH 0493/1089] more --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4accef613cd..3d0417b201d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -84,6 +84,7 @@ class TestFile: TestFile("test_data_parallelism.py", 90), TestFile("test_dp_attention.py", 90), TestFile("test_expert_distribution.py", 100), + TestFile("test_eplb.py", 100), TestFile("test_mla_tp.py", 420), TestFile("test_moe_ep.py", 220), TestFile("test_patch_torch.py", 30), From c3efb6c063a6bf1c824f2e0522d557fca52952a8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:06:09 +0800 Subject: [PATCH 0494/1089] more --- test/srt/test_eplb.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 2295eee78c9..f807f5dd925 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,9 +1,5 @@ -import os import unittest -import requests -import torch - from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -15,7 +11,9 @@ class TestEPLB(CustomTestCase): - TODO + def test_eplb_e2e(self): + TODO + if __name__ == "__main__": unittest.main() From f6c5c3858284e4749eed214f1f859a57acc124c0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:07:23 +0800 Subject: [PATCH 0495/1089] more --- test/srt/test_eplb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f807f5dd925..577cd089a84 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -2,17 +2,19 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_server, ) +import sglang as sgl class TestEPLB(CustomTestCase): def test_eplb_e2e(self): - TODO + engine = sgl.Engine(model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST) + engine.shutdown() if __name__ == "__main__": From adc3a036562a397440149a3b369c9347dff755af Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:07:57 +0800 
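The test commits above converge on a launch, assert, shutdown, relaunch cycle for the EPLB engine test. A condensed sketch of that shape; the prompt and sampling parameters are the ones the finished test later in this series uses, everything else here is illustrative:

import sglang as sgl


def launch_and_generate(engine_kwargs: dict) -> list:
    # One engine lifecycle: start, generate deterministically, stop.
    engine = sgl.Engine(**engine_kwargs)
    output = engine.generate(
        prompt=["1+1=2, 2+2=4"],
        sampling_params=dict(max_new_tokens=8, temperature=0.0),
    )
    engine.shutdown()
    return [x["text"] for x in output]


# The test then requires a relaunched engine to reproduce the first run:
#   assert launch_and_generate(kwargs) == launch_and_generate(kwargs)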
Subject: [PATCH 0496/1089] more --- test/srt/test_eplb.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 577cd089a84..3215b1a1df2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,5 +1,6 @@ import unittest +import sglang as sgl from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -8,13 +9,24 @@ CustomTestCase, popen_launch_server, ) -import sglang as sgl class TestEPLB(CustomTestCase): def test_eplb_e2e(self): engine = sgl.Engine(model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST) + self._assert_behavior(engine) + + engine.shutdown() + del engine + + engine = sgl.Engine(model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST) + self._assert_behavior(engine) + engine.shutdown() + del engine + + def _assert_behavior(self, engine: sgl.Engine): + TODO if __name__ == "__main__": From 3af15faaf2b29a8737b71baa8620754faf85fb79 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:08:44 +0800 Subject: [PATCH 0497/1089] more --- test/srt/test_eplb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 3215b1a1df2..1a8907d762f 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -26,6 +26,7 @@ def test_eplb_e2e(self): del engine def _assert_behavior(self, engine: sgl.Engine): + engine.flu TODO From 72e7d4480aea36afe4851522d7d72da8b94f4825 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:10:10 +0800 Subject: [PATCH 0498/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 33aab232fe3..4e11e4b2e59 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -278,6 +278,10 @@ def __exit__(self, exc_type, exc_value, traceback): self.shutdown() return False + def flush_cache(self): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.tokenizer_manager.flush_cache()) + def start_profile(self): loop = asyncio.get_event_loop() loop.run_until_complete(self.tokenizer_manager.start_profile()) From f3506bb3ea63c01db50c93bb8cf95dd8e258531c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:12:25 +0800 Subject: [PATCH 0499/1089] rename --- python/sglang/srt/managers/scheduler.py | 6 +++--- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 383cd680945..3ba027c7043 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -59,7 +59,7 @@ CloseSessionReqInput, ExpertDistributionReq, ExpertDistributionReqOutput, - FlushCacheReq, + FlushCacheReqInput, GetInternalStateReq, GetInternalStateReqOutput, GetWeightsByNameReqInput, @@ -400,7 +400,7 @@ def __init__( [ (TokenizedGenerateReqInput, self.handle_generate_request), (TokenizedEmbeddingReqInput, self.handle_embedding_request), - (FlushCacheReq, self.flush_cache_wrapped), + (FlushCacheReqInput, self.flush_cache_wrapped), (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), @@ -1652,7 +1652,7 @@ def watchdog_thread(self): time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) - def flush_cache_wrapped(self, recv_req: FlushCacheReq): + def 
flush_cache_wrapped(self, recv_req: FlushCacheReqInput): self.flush_cache() def flush_cache(self): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 33afffbd6de..358b4ae9fe2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -62,7 +62,7 @@ EmbeddingReqInput, ExpertDistributionReq, ExpertDistributionReqOutput, - FlushCacheReq, + FlushCacheReqInput, GenerateReqInput, GetInternalStateReq, GetInternalStateReqOutput, @@ -617,7 +617,7 @@ async def _handle_batch_request( pass def flush_cache(self): - req = FlushCacheReq() + req = FlushCacheReqInput() self.send_to_scheduler.send_pyobj(req) def abort_request(self, rid: str): From 518b4fe553d347161fc77e83bc67cb8e447d82e2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:12:44 +0800 Subject: [PATCH 0500/1089] more --- python/sglang/srt/managers/io_struct.py | 7 +++++- python/sglang/srt/managers/scheduler.py | 25 +++++++++---------- .../sglang/srt/managers/tokenizer_manager.py | 5 ++-- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 00affa0a4ed..ca92fdf2cc1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -665,7 +665,12 @@ class BatchEmbeddingOut: @dataclass -class FlushCacheReq: +class FlushCacheReqInput: + pass + + +@dataclass +class FlushCacheReqOutput: pass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3ba027c7043..9359107009a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,8 +32,6 @@ import setproctitle import torch import zmq -from torch.distributed import barrier - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -131,6 +129,7 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from torch.distributed import barrier expert_distribution_recorder = ExpertDistributionRecorder() @@ -371,8 +370,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1260,10 +1259,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1288,9 +1287,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1522,8 +1521,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. 
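With the rename above, cache flushing joins the same request/response scheme; the commits just below give FlushCacheReqOutput a success flag and make the scheduler handler return it. A toy sketch of where that lands (the dataclasses mirror io_struct.py, the scheduler internals are stand-ins):

from dataclasses import dataclass


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


class ToyScheduler:
    def __init__(self):
        self.has_unfinished_requests = False  # stand-in condition

    def flush_cache(self) -> bool:
        # The real scheduler only flushes when no requests are running or waiting.
        return not self.has_unfinished_requests

    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput) -> FlushCacheReqOutput:
        return FlushCacheReqOutput(success=self.flush_cache())


assert ToyScheduler().flush_cache_wrapped(FlushCacheReqInput()).success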
max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 358b4ae9fe2..35b4e80b148 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -959,8 +958,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 3db8c6125f1655773da9066f39b88e2a13e585a6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:13:31 +0800 Subject: [PATCH 0501/1089] more --- python/sglang/srt/managers/io_struct.py | 2 +- python/sglang/srt/managers/scheduler.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index ca92fdf2cc1..d2d5c591a2d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -671,7 +671,7 @@ class FlushCacheReqInput: @dataclass class FlushCacheReqOutput: - pass + success: bool @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9359107009a..e7a0d3e6a68 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1652,7 +1652,8 @@ def watchdog_thread(self): self.parent_process.send_signal(signal.SIGQUIT) def flush_cache_wrapped(self, recv_req: FlushCacheReqInput): - self.flush_cache() + success = self.flush_cache() + return FlushCacheReqOutput(success=success) def flush_cache(self): """Flush the memory pool and cache.""" From 14b7d6e788d5031b9a0143eac559d8744c03125c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:13:51 +0800 Subject: [PATCH 0502/1089] more --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e7a0d3e6a68..caaba2e93b3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -85,7 +85,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, FlushCacheReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, From 4ac95af6924ff1781650c391fdeda787dff52ab0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:14:39 +0800 Subject: [PATCH 0503/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 35b4e80b148..e59c42bd731 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ 
b/python/sglang/srt/managers/tokenizer_manager.py @@ -87,7 +87,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, FlushCacheReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -257,6 +257,9 @@ def __init__( self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.flush_cache_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.start_profile_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -307,6 +310,10 @@ def __init__( ResumeMemoryOccupationReqOutput, self.resume_memory_occupation_communicator.handle_recv, ), + ( + FlushCacheReqOutput, + self.flush_cache_communicator.handle_recv, + ), ( ProfileReqOutput, self.start_profile_communicator.handle_recv, From 5b9d05fb9d16b5658c1c7ba697348d7a68c68649 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:14:59 +0800 Subject: [PATCH 0504/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e59c42bd731..9d2018b42af 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -622,9 +622,8 @@ async def _handle_batch_request( except StopAsyncIteration: pass - def flush_cache(self): - req = FlushCacheReqInput() - self.send_to_scheduler.send_pyobj(req) + async def flush_cache(self): + return await self.flush_cache_communicator(FlushCacheReqInput()) def abort_request(self, rid: str): if rid not in self.rid_to_state: From 46f079e1dd5bebf8b57a7e5f07efda8449a94146 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:15:13 +0800 Subject: [PATCH 0505/1089] more --- python/sglang/srt/entrypoints/http_server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1f93b475c26..9bc8286789c 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -310,10 +310,10 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): @app.api_route("/flush_cache", methods=["GET", "POST"]) async def flush_cache(): """Flush the radix cache.""" - _global_state.tokenizer_manager.flush_cache() + await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -637,10 +637,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, From c9a6c4192636b37e73afd2c34a8ebd311b740fcf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:15:32 +0800 Subject: [PATCH 0506/1089] fmt --- python/sglang/srt/entrypoints/http_server.py | 10 +++---- python/sglang/srt/managers/scheduler.py | 28 ++++++++++--------- .../sglang/srt/managers/tokenizer_manager.py | 8 ++++-- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9bc8286789c..11ded8126db 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -637,10 +637,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index caaba2e93b3..24f5d4aa1af 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -32,6 +32,8 @@ import setproctitle import torch import zmq +from torch.distributed import barrier + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -58,6 +60,7 @@ ExpertDistributionReq, ExpertDistributionReqOutput, FlushCacheReqInput, + FlushCacheReqOutput, GetInternalStateReq, GetInternalStateReqOutput, GetWeightsByNameReqInput, @@ -85,7 +88,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, FlushCacheReqOutput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -129,7 +132,6 @@ suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -from torch.distributed import barrier expert_distribution_recorder = ExpertDistributionRecorder() @@ -370,8 +372,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / 
global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Init watchdog thread @@ -1259,10 +1261,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.running_batch.batch_is_full = True @@ -1287,9 +1289,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = len( adder.can_run_list ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + self.running_batch is not None + and not self.running_batch.is_empty() + ) else: self.running_batch.batch_is_full = True break @@ -1521,8 +1523,8 @@ def prepare_dp_attn_batch_raw( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens - ) + local_batch.extend_logprob_start_lens, local_batch.extend_lens + ) ] ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9d2018b42af..ca89361e47d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -62,6 +63,7 @@ ExpertDistributionReq, ExpertDistributionReqOutput, FlushCacheReqInput, + FlushCacheReqOutput, GenerateReqInput, GetInternalStateReq, GetInternalStateReqOutput, @@ -87,7 +89,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, FlushCacheReqOutput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -964,8 +966,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 8b413318494fdffaed9d5221b04456f11540f4ac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:17:14 +0800 Subject: [PATCH 0507/1089] more --- python/sglang/srt/entrypoints/http_server.py | 4 ++-- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 11ded8126db..2e5944a70b9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -310,11 +310,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): @app.api_route("/flush_cache", methods=["GET", "POST"]) async def flush_cache(): """Flush the radix cache.""" - await _global_state.tokenizer_manager.flush_cache() + ret = await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, + status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ca89361e47d..93e74ad27fc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -624,7 +624,7 @@ async def _handle_batch_request( except StopAsyncIteration: pass - async def flush_cache(self): + async def flush_cache(self) -> FlushCacheReqOutput: return await self.flush_cache_communicator(FlushCacheReqInput()) def abort_request(self, rid: str): From 2d16f97f26a09b103783ddd52f987808fb551b32 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:18:09 +0800 Subject: [PATCH 0508/1089] more --- python/sglang/srt/entrypoints/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 4e11e4b2e59..bf7371fff68 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -280,7 +280,7 @@ def __exit__(self, exc_type, exc_value, traceback): def flush_cache(self): loop = asyncio.get_event_loop() - loop.run_until_complete(self.tokenizer_manager.flush_cache()) + return loop.run_until_complete(self.tokenizer_manager.flush_cache()) def start_profile(self): loop = asyncio.get_event_loop() From 748893521e069f2e9de2178ca9084fe41589cd83 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:19:20 +0800 Subject: [PATCH 0509/1089] more --- test/srt/test_eplb.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 1a8907d762f..c4d2ed6be8b 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -26,7 +26,9 @@ def test_eplb_e2e(self): del engine def _assert_behavior(self, engine: sgl.Engine): - engine.flu + ret = engine.flush_cache() + assert ret.success + TODO From 275c0e2ab799f70cc0efd10ae07712f6e615cb01 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:20:12 +0800 Subject: [PATCH 0510/1089] more --- test/srt/test_eplb.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index c4d2ed6be8b..88f7a487020 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -13,13 +13,18 @@ class TestEPLB(CustomTestCase): def test_eplb_e2e(self): - engine = sgl.Engine(model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST) + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + eplb_storage_dir=TODO, + ) + + engine = sgl.Engine( **engine_kwargs) self._assert_behavior(engine) engine.shutdown() del engine - engine = sgl.Engine(model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST) + engine = sgl.Engine( **engine_kwargs) self._assert_behavior(engine) engine.shutdown() From 9c126360d8f51e7ab4eb017031da99144c281a1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:20:37 +0800 Subject: [PATCH 0511/1089] more --- test/srt/test_eplb.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 88f7a487020..a6d6da46657 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,3 +1,4 @@ +import tempfile import unittest import sglang as sgl @@ -13,22 +14,23 @@ class TestEPLB(CustomTestCase): def test_eplb_e2e(self): - engine_kwargs = dict( - 
model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, - eplb_storage_dir=TODO, - ) + with tempfile.TemporaryDirectory() as tmpdir: + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + eplb_storage_dir=tmpdir, + ) - engine = sgl.Engine( **engine_kwargs) - self._assert_behavior(engine) + engine = sgl.Engine( **engine_kwargs) + self._assert_behavior(engine) - engine.shutdown() - del engine + engine.shutdown() + del engine - engine = sgl.Engine( **engine_kwargs) - self._assert_behavior(engine) + engine = sgl.Engine( **engine_kwargs) + self._assert_behavior(engine) - engine.shutdown() - del engine + engine.shutdown() + del engine def _assert_behavior(self, engine: sgl.Engine): ret = engine.flush_cache() From 149781a5491f13c0f9ee3bfdab21693929a310fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:21:33 +0800 Subject: [PATCH 0512/1089] more --- test/srt/test_eplb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index a6d6da46657..9c76b6a8e18 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -20,13 +20,13 @@ def test_eplb_e2e(self): eplb_storage_dir=tmpdir, ) - engine = sgl.Engine( **engine_kwargs) + engine = sgl.Engine(**engine_kwargs) self._assert_behavior(engine) engine.shutdown() del engine - engine = sgl.Engine( **engine_kwargs) + engine = sgl.Engine(**engine_kwargs) self._assert_behavior(engine) engine.shutdown() @@ -36,6 +36,8 @@ def _assert_behavior(self, engine: sgl.Engine): ret = engine.flush_cache() assert ret.success + engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"]) + TODO From 507a44ff9b228abd102aa4871ffca9953872b462 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:22:39 +0800 Subject: [PATCH 0513/1089] more --- test/srt/test_eplb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 9c76b6a8e18..af9f8c9e9f7 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -36,7 +36,8 @@ def _assert_behavior(self, engine: sgl.Engine): ret = engine.flush_cache() assert ret.success - engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"]) + output = engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + sampling_params=dict(max_new_tokens=8, temperature=0.0)) TODO From e0bcfa3f588aa47f65627db5c92d72eba268ab33 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:25:18 +0800 Subject: [PATCH 0514/1089] more --- test/srt/test_eplb.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index af9f8c9e9f7..0db2f964728 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,5 +1,6 @@ import tempfile import unittest +from typing import List import sglang as sgl from sglang.srt.utils import kill_process_tree @@ -21,26 +22,33 @@ def test_eplb_e2e(self): ) engine = sgl.Engine(**engine_kwargs) - self._assert_behavior(engine) + ref_output = self._engine_generate(engine) + self._assert_behavior(engine, ref_output) engine.shutdown() del engine engine = sgl.Engine(**engine_kwargs) - self._assert_behavior(engine) + self._assert_behavior(engine, ref_output) engine.shutdown() del engine - def _assert_behavior(self, engine: sgl.Engine): + def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): ret = engine.flush_cache() assert ret.success - output = engine.generate(prompt=["1+1=2, 2+2=4", "One 
plus one is two, two plus two is four"], - sampling_params=dict(max_new_tokens=8, temperature=0.0)) + actual_output = self._engine_generate(engine) + self.assertEqual(actual_output, ref_output) TODO + def _engine_generate(self, engine: sgl.Engine): + output = engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + sampling_params=dict(max_new_tokens=8, temperature=0.0)) + print(f"engine_generate {output=}") + return [x["text"] for x in output] + if __name__ == "__main__": unittest.main() From 1da2438ad04a8f567adb796b06520a19aa96c19e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:25:48 +0800 Subject: [PATCH 0515/1089] more --- test/srt/test_eplb.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0db2f964728..5840c5e35e2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -25,12 +25,18 @@ def test_eplb_e2e(self): ref_output = self._engine_generate(engine) self._assert_behavior(engine, ref_output) + engine.eplb_rebalance() + self._assert_behavior(engine, ref_output) + engine.shutdown() del engine engine = sgl.Engine(**engine_kwargs) self._assert_behavior(engine, ref_output) + engine.eplb_rebalance() + self._assert_behavior(engine, ref_output) + engine.shutdown() del engine From a8a8ffe12de69fcbb9cb892ba8be17ebee3718c9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:27:59 +0800 Subject: [PATCH 0516/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++++ python/sglang/srt/managers/tokenizer_manager.py | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index dbd11f382c9..6d4408b927c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -360,6 +360,10 @@ def update_weights_from_tensor( self.tokenizer_manager.update_weights_from_tensor(obj, None) ) + def eplb_rebalance(self): + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.eplb_rebalance()) + def update_expert_location(self, expert_location_metadata: ExpertLocationMetadata): obj = UpdateExpertLocationReqInput( expert_location_metadata=expert_location_metadata, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 200b9bfd65c..53f1d4e2c51 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -686,6 +685,10 @@ async def dump_expert_distribution_record(self): expert_location_metadata=self.expert_location_metadata, ) + async def eplb_rebalance(self): + self.auto_create_handle_loop() + await self.eplb_manager.rebalance() + async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() assert ( @@ -1009,8 +1012,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 6e3e673a3b0851784cb547c5955e6321f9ee8080 Mon Sep 17 
00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:28:16 +0800 Subject: [PATCH 0517/1089] more --- python/sglang/srt/entrypoints/http_server.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index f43afec5423..fe729290880 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -370,6 +370,12 @@ async def dump_expert_distribution_record_async(): return ORJSONResponse(content, status_code=200) +@app.post("/eplb_rebalance") +async def eplb_rebalance(): + await _global_state.tokenizer_manager.eplb_rebalance() + return ORJSONResponse({}, status_code=200) + + @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk inplace without re-launching the server.""" @@ -634,10 +640,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, From 78b9e64858246e7e4dba1a89e1920c1d5c5e5883 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:28:39 +0800 Subject: [PATCH 0518/1089] fmt --- python/sglang/srt/entrypoints/http_server.py | 10 +++++----- python/sglang/srt/managers/tokenizer_manager.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index fe729290880..f0a6f946bb4 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -640,10 +640,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 53f1d4e2c51..67df63762db 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -1012,8 +1013,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From fe202f265a111a9e97214172313075f8caa96770 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:29:56 +0800 Subject: [PATCH 0519/1089] more --- test/srt/test_eplb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 5840c5e35e2..e833a53f1bc 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -47,6 +47,7 @@ def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): actual_output = self._engine_generate(engine) self.assertEqual(actual_output, ref_output) + physical_to_logical_map = engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map TODO def _engine_generate(self, engine: sgl.Engine): From b080880f7ad969c1108c5ce3f7045fbfddec4672 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:30:04 +0800 Subject: [PATCH 0520/1089] more --- test/srt/test_eplb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e833a53f1bc..7d993c20315 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -19,6 +19,7 @@ def test_eplb_e2e(self): engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, eplb_storage_dir=tmpdir, + tp_size=2, ) engine = sgl.Engine(**engine_kwargs) From 07fcc4dd62da0d63a94964e8c7a34315a98b4d98 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:30:48 +0800 Subject: [PATCH 0521/1089] more --- test/srt/test_eplb.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 7d993c20315..20ada29f5f9 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,6 +3,7 @@ from typing import List import sglang as sgl +import torch from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -15,6 +16,8 @@ class TestEPLB(CustomTestCase): def test_eplb_e2e(self): + TODO_test_redundant_experts + with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -49,6 +52,7 
@@ def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): self.assertEqual(actual_output, ref_output) physical_to_logical_map = engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map + assert torch.all(physical_to_logical_map[0, :] == TODO) TODO def _engine_generate(self, engine: sgl.Engine): From cf6b689d009d115f5298ef4019993249c4d2e02f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:32:16 +0800 Subject: [PATCH 0522/1089] more --- test/srt/test_eplb.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 20ada29f5f9..461ebd29b5d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -13,15 +13,18 @@ popen_launch_server, ) +# P.S. DeepSeek-Coder-V2-Lite-Instruct has 64 routed experts +_EP_NUM_REDUNDANT_EXPERTS = 4 + class TestEPLB(CustomTestCase): def test_eplb_e2e(self): - TODO_test_redundant_experts - with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + enable_eplb=True, eplb_storage_dir=tmpdir, + ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, tp_size=2, ) From 5ef94f20e05db5d758df03e2200714635953bbd9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:33:12 +0800 Subject: [PATCH 0523/1089] more --- test/srt/test_eplb.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 461ebd29b5d..8ea90bd019c 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,7 +3,6 @@ from typing import List import sglang as sgl -import torch from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -28,22 +27,24 @@ def test_eplb_e2e(self): tp_size=2, ) + print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) ref_output = self._engine_generate(engine) self._assert_behavior(engine, ref_output) + print(f"Action: eplb_rebalance") engine.eplb_rebalance() self._assert_behavior(engine, ref_output) + print(f"Action: shutdown engine") engine.shutdown() del engine + print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) self._assert_behavior(engine, ref_output) - engine.eplb_rebalance() - self._assert_behavior(engine, ref_output) - + print(f"Action: shutdown engine") engine.shutdown() del engine @@ -55,8 +56,8 @@ def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): self.assertEqual(actual_output, ref_output) physical_to_logical_map = engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map - assert torch.all(physical_to_logical_map[0, :] == TODO) - TODO + physical_to_logical_map_layer_0 = physical_to_logical_map[0, :].tolist() + print(f"{physical_to_logical_map_layer_0=}") def _engine_generate(self, engine: sgl.Engine): output = engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], From b74588d7b315c38dace03b64e86a00c818ee0a5c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:34:48 +0800 Subject: [PATCH 0524/1089] more --- test/srt/test_eplb.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 8ea90bd019c..113faad3264 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -30,11 +30,11 @@ def test_eplb_e2e(self): print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) ref_output = self._engine_generate(engine) - self._assert_behavior(engine, 
ref_output) + self._assert_behavior(engine, ref_output, "equal_trivial") print(f"Action: eplb_rebalance") - engine.eplb_rebalance() - self._assert_behavior(engine, ref_output) + physical_to_logical_map_layer_0_after_first_rebalance = engine.eplb_rebalance() + self._assert_behavior(engine, ref_output, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() @@ -42,13 +42,13 @@ def test_eplb_e2e(self): print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) - self._assert_behavior(engine, ref_output) + self._assert_behavior(engine, ref_output, physical_to_logical_map_layer_0_after_first_rebalance) print(f"Action: shutdown engine") engine.shutdown() del engine - def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): + def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map): ret = engine.flush_cache() assert ret.success @@ -59,6 +59,15 @@ def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str]): physical_to_logical_map_layer_0 = physical_to_logical_map[0, :].tolist() print(f"{physical_to_logical_map_layer_0=}") + if expect_physical_to_local_map == "equal_trivial": + self.assertEqual(physical_to_logical_map_layer_0, list(range(TODO))) + elif expect_physical_to_local_map == "not_equal_trivial": + self.assertNotEqual(physical_to_logical_map_layer_0, list(range(TODO))) + else: + self.assertEqual(physical_to_logical_map_layer_0, expect_physical_to_local_map) + + return physical_to_logical_map_layer_0 + def _engine_generate(self, engine: sgl.Engine): output = engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], sampling_params=dict(max_new_tokens=8, temperature=0.0)) From c73c9270f5de616c20d49ab2d26985cb2f29dac6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:35:32 +0800 Subject: [PATCH 0525/1089] more --- test/srt/test_eplb.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 113faad3264..0423ac01e46 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -12,8 +12,9 @@ popen_launch_server, ) -# P.S. 
DeepSeek-Coder-V2-Lite-Instruct has 64 routed experts +_NUM_ROUTED_EXPERTS = 64 # DeepSeek-Coder-V2-Lite-Instruct _EP_NUM_REDUNDANT_EXPERTS = 4 +_NUM_OVERALL_PHYSICAL_EXPERTS = _NUM_ROUTED_EXPERTS + _EP_NUM_REDUNDANT_EXPERTS class TestEPLB(CustomTestCase): @@ -60,9 +61,9 @@ def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str], expect_phy print(f"{physical_to_logical_map_layer_0=}") if expect_physical_to_local_map == "equal_trivial": - self.assertEqual(physical_to_logical_map_layer_0, list(range(TODO))) + self.assertEqual(physical_to_logical_map_layer_0, list(range(_NUM_OVERALL_PHYSICAL_EXPERTS))) elif expect_physical_to_local_map == "not_equal_trivial": - self.assertNotEqual(physical_to_logical_map_layer_0, list(range(TODO))) + self.assertNotEqual(physical_to_logical_map_layer_0, list(range(_NUM_OVERALL_PHYSICAL_EXPERTS))) else: self.assertEqual(physical_to_logical_map_layer_0, expect_physical_to_local_map) From 83950be2c2f2498e7aa043ca2c762feb90c7b9e8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 10 Apr 2025 22:35:57 +0800 Subject: [PATCH 0526/1089] fmt --- test/srt/test_eplb.py | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0423ac01e46..bda3bc9bd3d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -34,7 +34,9 @@ def test_eplb_e2e(self): self._assert_behavior(engine, ref_output, "equal_trivial") print(f"Action: eplb_rebalance") - physical_to_logical_map_layer_0_after_first_rebalance = engine.eplb_rebalance() + physical_to_logical_map_layer_0_after_first_rebalance = ( + engine.eplb_rebalance() + ) self._assert_behavior(engine, ref_output, "not_equal_trivial") print(f"Action: shutdown engine") @@ -43,35 +45,53 @@ def test_eplb_e2e(self): print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) - self._assert_behavior(engine, ref_output, physical_to_logical_map_layer_0_after_first_rebalance) + self._assert_behavior( + engine, + ref_output, + physical_to_logical_map_layer_0_after_first_rebalance, + ) print(f"Action: shutdown engine") engine.shutdown() del engine - def _assert_behavior(self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map): + def _assert_behavior( + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + ): ret = engine.flush_cache() assert ret.success actual_output = self._engine_generate(engine) self.assertEqual(actual_output, ref_output) - physical_to_logical_map = engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map + physical_to_logical_map = ( + engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map + ) physical_to_logical_map_layer_0 = physical_to_logical_map[0, :].tolist() print(f"{physical_to_logical_map_layer_0=}") if expect_physical_to_local_map == "equal_trivial": - self.assertEqual(physical_to_logical_map_layer_0, list(range(_NUM_OVERALL_PHYSICAL_EXPERTS))) + self.assertEqual( + physical_to_logical_map_layer_0, + list(range(_NUM_OVERALL_PHYSICAL_EXPERTS)), + ) elif expect_physical_to_local_map == "not_equal_trivial": - self.assertNotEqual(physical_to_logical_map_layer_0, list(range(_NUM_OVERALL_PHYSICAL_EXPERTS))) + self.assertNotEqual( + physical_to_logical_map_layer_0, + list(range(_NUM_OVERALL_PHYSICAL_EXPERTS)), + ) else: - self.assertEqual(physical_to_logical_map_layer_0, expect_physical_to_local_map) + self.assertEqual( + physical_to_logical_map_layer_0, expect_physical_to_local_map + ) return 
physical_to_logical_map_layer_0 def _engine_generate(self, engine: sgl.Engine): - output = engine.generate(prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], - sampling_params=dict(max_new_tokens=8, temperature=0.0)) + output = engine.generate( + prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + sampling_params=dict(max_new_tokens=8, temperature=0.0), + ) print(f"engine_generate {output=}") return [x["text"] for x in output] From 452524adfe45c3464fc0a0c9227e78d504885376 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:33:31 +0800 Subject: [PATCH 0527/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 91c58b6119b..64b3cc5db65 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -697,6 +697,10 @@ async def eplb_rebalance(self): self.auto_create_handle_loop() await self.eplb_manager.rebalance() + async def eplb_save_expert_distribution(self): + self.auto_create_handle_loop() + self.eplb_manager.save_expert_distribution() + async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() assert ( From 3bcf2502c54ad6e7f53c25b2b779e79318ef1a58 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:34:00 +0800 Subject: [PATCH 0528/1089] more --- python/sglang/srt/entrypoints/engine.py | 5 ++++- python/sglang/srt/entrypoints/http_server.py | 16 +++++++++++----- python/sglang/srt/managers/tokenizer_manager.py | 5 ++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index cc650266253..8cc44807703 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,7 +31,6 @@ import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -368,6 +367,10 @@ def eplb_rebalance(self): loop = asyncio.get_event_loop() return loop.run_until_complete(self.tokenizer_manager.eplb_rebalance()) + def eplb_save_expert_distribution(self): + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.eplb_save_expert_distribution()) + def update_expert_location(self, expert_location_metadata: ExpertLocationMetadata): obj = UpdateExpertLocationReqInput( expert_location_metadata=expert_location_metadata, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index d0c455ea5e6..a9aa47d2e27 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): ret = await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, ) @@ -376,6 +376,12 @@ async def eplb_rebalance(): return ORJSONResponse({}, status_code=200) +@app.post("/eplb_save_expert_distribution") +async def eplb_save_expert_distribution(): + await _global_state.tokenizer_manager.eplb_save_expert_distribution() + return ORJSONResponse({}, status_code=200) + + @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk inplace without re-launching the server.""" @@ -640,10 +646,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 64b3cc5db65..ba50c4818e7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -1024,8 +1023,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From de3ecd279e7ab0ff9b530e83025ef611d65fedb2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:38:10 +0800 Subject: [PATCH 0529/1089] more --- python/sglang/srt/managers/expert_location.py | 25 +++++++++++-------- .../sglang/srt/model_executor/model_runner.py | 6 ++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 48bef6a39ab..c39ca46ad6b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -1,9 +1,10 @@ +import dataclasses +import json import random from dataclasses import dataclass from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -33,8 +34,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -49,7 +50,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: 
ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -104,8 +105,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -148,7 +149,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -156,7 +157,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -166,10 +167,14 @@ def logical_to_all_physical_raw( if physical_expert_id != -1 ] + def debug_str(self): + return json.dumps( + {k: v.tolist() if isinstance(v, torch.Tensor) else v for k, v in dataclasses.asdict(self).items()}) + def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 41e862e8228..3f3c5cdf186 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -87,7 +87,7 @@ monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, - set_cuda_arch, + set_cuda_arch, get_bool_env_var, ) logger = logging.getLogger(__name__) @@ -179,6 +179,8 @@ def __init__( } ) set_global_expert_location_metadata(expert_location_metadata) + if self.tp_rank == 0 and get_bool_env_var("SGLANG_LOG_EXPERT_LOCATION_METADATA"): + logger.info(f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}") # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) @@ -481,6 +483,8 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): torch.distributed.barrier() get_global_expert_location_metadata().update(recv_req.expert_location_metadata) + if self.tp_rank == 0 and get_bool_env_var("SGLANG_LOG_EXPERT_LOCATION_METADATA"): + logger.info(f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}") # We may be able to further reduce lock time by faster copying, pre-transfering, etc self.update_weights_from_disk( From 3a61ef3f335c7d3f6e937ea87f4d5333968ec218 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:38:39 +0800 Subject: [PATCH 0530/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 5 +++- python/sglang/srt/entrypoints/http_server.py | 10 ++++---- python/sglang/srt/managers/expert_location.py | 25 +++++++++++-------- .../sglang/srt/managers/tokenizer_manager.py | 5 ++-- .../sglang/srt/model_executor/model_runner.py | 19 ++++++++++---- 5 files changed, 41 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 8cc44807703..a2717c9005c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,6 +31,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -369,7 +370,9 @@ def eplb_rebalance(self): def eplb_save_expert_distribution(self): loop = asyncio.get_event_loop() - return loop.run_until_complete(self.tokenizer_manager.eplb_save_expert_distribution()) + return loop.run_until_complete( + self.tokenizer_manager.eplb_save_expert_distribution() + ) def update_expert_location(self, expert_location_metadata: ExpertLocationMetadata): obj = UpdateExpertLocationReqInput( diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a9aa47d2e27..9ddf72cad99 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): ret = await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, ) @@ -646,10 +646,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c39ca46ad6b..d5f54319641 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -34,8 +35,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -50,7 +51,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -105,8 +106,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -149,7 +150,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -157,7 +158,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -169,12 +170,16 @@ def logical_to_all_physical_raw( def debug_str(self): return json.dumps( - {k: v.tolist() if isinstance(v, torch.Tensor) else v for k, v in dataclasses.asdict(self).items()}) + { + k: v.tolist() if isinstance(v, torch.Tensor) else v + for k, v in dataclasses.asdict(self).items() + } + ) def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: 
torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ba50c4818e7..64b3cc5db65 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -1023,8 +1024,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3f3c5cdf186..174cc1cf2b1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -80,6 +80,7 @@ MultiprocessingSerializer, enable_show_time_cost, get_available_gpu_memory, + get_bool_env_var, init_custom_process_group, is_cuda, is_flashinfer_available, @@ -87,7 +88,7 @@ monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, - set_cuda_arch, get_bool_env_var, + set_cuda_arch, ) logger = logging.getLogger(__name__) @@ -179,8 +180,12 @@ def __init__( } ) set_global_expert_location_metadata(expert_location_metadata) - if self.tp_rank == 0 and get_bool_env_var("SGLANG_LOG_EXPERT_LOCATION_METADATA"): - logger.info(f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}") + if self.tp_rank == 0 and get_bool_env_var( + "SGLANG_LOG_EXPERT_LOCATION_METADATA" + ): + logger.info( + f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" + ) # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) @@ -483,8 +488,12 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): torch.distributed.barrier() get_global_expert_location_metadata().update(recv_req.expert_location_metadata) - if self.tp_rank == 0 and get_bool_env_var("SGLANG_LOG_EXPERT_LOCATION_METADATA"): - logger.info(f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}") + if self.tp_rank == 0 and get_bool_env_var( + "SGLANG_LOG_EXPERT_LOCATION_METADATA" + ): + logger.info( + f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" + ) # We may be able to further reduce lock time by faster copying, pre-transfering, etc self.update_weights_from_disk( From b264b77e78c41e0258f7b6dbd6b255af3886f33a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:41:18 +0800 Subject: [PATCH 0531/1089] more --- python/sglang/srt/entrypoints/engine.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a2717c9005c..d047bfb621b 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,7 +31,6 @@ import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import 
ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -649,8 +648,14 @@ def _compute_initial_expert_location_metadata( server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: - # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used - return ExpertLocationMetadata.init_by_mapping(server_args, **json.loads(data)) + data_dict = json.loads(data) + if "physical_to_logical_map" in data_dict: + # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used + return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) + elif "logical_count" in data_dict: + return ExpertLocationMetadata.init_by_eplb(server_args, **data_dict) + else: + raise NotImplementedError(f"Unknown init_expert_location format ({list(data_dict.keys())=})") if server_args.enable_eplb: return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial(server_args) From 3e8f979c2451d3206ea750828fd028cae1482f4c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:44:53 +0800 Subject: [PATCH 0532/1089] more --- test/srt/test_eplb.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index bda3bc9bd3d..c26ac298ac1 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -55,8 +55,43 @@ def test_eplb_e2e(self): engine.shutdown() del engine + def test_eplb_init_expert_location_and_save_expert_distribution(self): + with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + enable_eplb=True, + ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + tp_size=2, + ) + + print(f"Action: start engine") + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + ref_output = self._engine_generate(engine) + self._assert_behavior(engine, ref_output, "equal_trivial") + + print(f"Action: eplb_save_expert_distribution") + engine.eplb_save_expert_distribution() + + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print(f"Action: start engine with init_expert_location") + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_b, init_expert_location=TODO) + self._assert_behavior(engine, ref_output, "not_equal_trivial") + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print(f"Action: start engine to check automatically loading from storage dir") + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + self._assert_behavior(engine, ref_output, "not_equal_trivial") + print(f"Action: shutdown engine") + engine.shutdown() + del engine + def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = engine.flush_cache() assert ret.success From c38a21f253a27a5e79d1ddc147154e9006e258c6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:45:52 +0800 Subject: [PATCH 0533/1089] more --- .../srt/managers/expert_distribution_storage.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py 
b/python/sglang/srt/managers/expert_distribution_storage.py index 9c9bffebf24..7111026a3c7 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -28,9 +28,15 @@ async def save_current(self): path.write_text(json.dumps(data)) def get_last_snapshot(self) -> Optional[Dict[str, Any]]: - paths = sorted(list(self._dir_data.glob("*.json")), key=lambda p: int(p.stem)) - if len(paths) == 0: + path = self.get_last_snapshot_path(self._dir_data) + if path is None: return None - path = paths[-1] logger.info(f"get_last_snapshot choose path {path}") return json.loads(path.read_text()) + + @staticmethod + def get_last_snapshot_path(dir_data: Path) -> Optional[Path]: + paths = sorted(list(dir_data.glob("*.json")), key=lambda p: int(p.stem)) + if len(paths) == 0: + return None + return paths[-1] From 247c67563dd4b91f043a90f85e14a741bda2a0fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:47:00 +0800 Subject: [PATCH 0534/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 4 ++-- test/srt/test_eplb.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index 7111026a3c7..8e785884119 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -35,8 +35,8 @@ def get_last_snapshot(self) -> Optional[Dict[str, Any]]: return json.loads(path.read_text()) @staticmethod - def get_last_snapshot_path(dir_data: Path) -> Optional[Path]: - paths = sorted(list(dir_data.glob("*.json")), key=lambda p: int(p.stem)) + def get_last_snapshot_path(dir_data) -> Optional[Path]: + paths = sorted(list(Path(dir_data).glob("*.json")), key=lambda p: int(p.stem)) if len(paths) == 0: return None return paths[-1] diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index c26ac298ac1..630a07cb58e 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,6 +3,7 @@ from typing import List import sglang as sgl +from python.sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -71,13 +72,16 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): print(f"Action: eplb_save_expert_distribution") engine.eplb_save_expert_distribution() + snapshot_path = ExpertDistributionStorage.get_last_snapshot_path(eplb_storage_dir_a) + assert snapshot_path is not None print(f"Action: shutdown engine") engine.shutdown() del engine print(f"Action: start engine with init_expert_location") - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_b, init_expert_location=TODO) + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_b, + init_expert_location=str(snapshot_path)) self._assert_behavior(engine, ref_output, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() From bba50db7b025d971220281deaf7c1552b406bbc5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:47:28 +0800 Subject: [PATCH 0535/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 5 ++++- test/srt/test_eplb.py | 21 +++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index d047bfb621b..6c945b884f6 100644 
--- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -31,6 +31,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -655,7 +656,9 @@ def _compute_initial_expert_location_metadata( elif "logical_count" in data_dict: return ExpertLocationMetadata.init_by_eplb(server_args, **data_dict) else: - raise NotImplementedError(f"Unknown init_expert_location format ({list(data_dict.keys())=})") + raise NotImplementedError( + f"Unknown init_expert_location format ({list(data_dict.keys())=})" + ) if server_args.enable_eplb: return eplb_manager.compute_expert_location_metadata() return ExpertLocationMetadata.init_trivial(server_args) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 630a07cb58e..f4abf7d64db 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,7 +3,9 @@ from typing import List import sglang as sgl -from python.sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage +from python.sglang.srt.managers.expert_distribution_storage import ( + ExpertDistributionStorage, +) from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -72,7 +74,9 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): print(f"Action: eplb_save_expert_distribution") engine.eplb_save_expert_distribution() - snapshot_path = ExpertDistributionStorage.get_last_snapshot_path(eplb_storage_dir_a) + snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( + eplb_storage_dir_a + ) assert snapshot_path is not None print(f"Action: shutdown engine") @@ -80,14 +84,19 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine print(f"Action: start engine with init_expert_location") - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_b, - init_expert_location=str(snapshot_path)) + engine = sgl.Engine( + **engine_kwargs, + eplb_storage_dir=eplb_storage_dir_b, + init_expert_location=str(snapshot_path), + ) self._assert_behavior(engine, ref_output, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() del engine - print(f"Action: start engine to check automatically loading from storage dir") + print( + f"Action: start engine to check automatically loading from storage dir" + ) engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) self._assert_behavior(engine, ref_output, "not_equal_trivial") print(f"Action: shutdown engine") @@ -95,7 +104,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = engine.flush_cache() assert ret.success From 1dc2c63217b9f1176a1540b52205f3443a94e73d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:49:04 +0800 Subject: [PATCH 0536/1089] more --- test/srt/test_eplb.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f4abf7d64db..0f02c176cd7 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -37,10 +37,9 @@ def test_eplb_e2e(self): self._assert_behavior(engine, ref_output, "equal_trivial") print(f"Action: eplb_rebalance") - 
physical_to_logical_map_layer_0_after_first_rebalance = ( - engine.eplb_rebalance() - ) - self._assert_behavior(engine, ref_output, "not_equal_trivial") + engine.eplb_rebalance() + physical_to_logical_map_layer_0_after_first_rebalance = self._assert_behavior(engine, ref_output, + "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() @@ -104,7 +103,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = engine.flush_cache() assert ret.success From 666823fbee05c51d678278e06a54cb239d018536 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:49:33 +0800 Subject: [PATCH 0537/1089] mv --- python/sglang/srt/model_executor/model_runner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 174cc1cf2b1..659d8d0b363 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -179,6 +179,13 @@ def __init__( "use_mla_backend": self.use_mla_backend, } ) + + # CPU offload + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + + # Get memory before model loading + min_per_gpu_memory = self.init_torch_distributed() + set_global_expert_location_metadata(expert_location_metadata) if self.tp_rank == 0 and get_bool_env_var( "SGLANG_LOG_EXPERT_LOCATION_METADATA" @@ -187,12 +194,6 @@ def __init__( f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" ) - # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) - - # Get memory before model loading - min_per_gpu_memory = self.init_torch_distributed() - # If it is a draft model tp_group can be different. 
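# A note on the move above (a sketch, not part of the patch): the next commit,
# PATCH 0538, calls `expert_location_metadata.to(server_args.device)`, so
# `init_torch_distributed()` has to have run first to set up the device and
# process groups. The resulting __init__ ordering, using only names that
# appear in these hunks:
#
#     set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
#     min_per_gpu_memory = self.init_torch_distributed()   # device ready here
#     set_global_expert_location_metadata(expert_location_metadata)
#     self.initialize(min_per_gpu_memory)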
self.initialize(min_per_gpu_memory) From 40c51583955a25b42ddd9a7659eaa3f6580f72d2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:50:17 +0800 Subject: [PATCH 0538/1089] more --- python/sglang/srt/managers/expert_location.py | 22 ++++++++++--------- .../sglang/srt/model_executor/model_runner.py | 1 + 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d5f54319641..d9ccc0d071f 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,7 +5,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -35,8 +34,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -51,7 +50,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -106,8 +105,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -141,6 +140,9 @@ def update(self, other: "ExpertLocationMetadata"): ]: setattr(self, field, getattr(other, field)) + def to(self, device): + self.logical_to_rank_dispatch_physical_map = self.logical_to_rank_dispatch_physical_map.to(device) + # -------------------------------- usage ------------------------------------ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): @@ -150,7 +152,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -158,7 +160,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -178,8 +180,8 @@ def debug_str(self): def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 659d8d0b363..297fb6b005d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -186,6 +186,7 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + expert_location_metadata.to(server_args.device) set_global_expert_location_metadata(expert_location_metadata) if self.tp_rank == 0 and get_bool_env_var( "SGLANG_LOG_EXPERT_LOCATION_METADATA" From e8779333118c2d16a52bae503789111037e32b71 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:50:37 +0800 Subject: [PATCH 0539/1089] more --- python/sglang/srt/managers/expert_location.py | 23 +++++++++++-------- test/srt/test_eplb.py | 7 +++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d9ccc0d071f..8502931625b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -34,8 +35,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -50,7 +51,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -105,8 +106,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -141,7 +142,9 @@ def update(self, other: "ExpertLocationMetadata"): setattr(self, field, getattr(other, field)) def to(self, device): - self.logical_to_rank_dispatch_physical_map = self.logical_to_rank_dispatch_physical_map.to(device) + self.logical_to_rank_dispatch_physical_map = ( + self.logical_to_rank_dispatch_physical_map.to(device) + ) # -------------------------------- usage ------------------------------------ @@ -152,7 +155,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id 
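# Aside: a toy illustration of the mapping helpers reformatted in this hunk
# (the tensor values are made up; -1 marks an empty slot, exactly as
# logical_to_all_physical_raw above treats it):
#
#     import torch
#     # 1 layer, 2 logical experts, up to 2 physical replicas per expert
#     m = torch.tensor([[[0, 64], [1, -1]]])
#     [p for p in m[0, 0].tolist() if p != -1]  # -> [0, 64]: expert 0 is replicated
#     [p for p in m[0, 1].tolist() if p != -1]  # -> [1]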
@@ -160,7 +163,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -180,8 +183,8 @@ def debug_str(self): def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0f02c176cd7..df6557e0b50 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -38,8 +38,9 @@ def test_eplb_e2e(self): print(f"Action: eplb_rebalance") engine.eplb_rebalance() - physical_to_logical_map_layer_0_after_first_rebalance = self._assert_behavior(engine, ref_output, - "not_equal_trivial") + physical_to_logical_map_layer_0_after_first_rebalance = ( + self._assert_behavior(engine, ref_output, "not_equal_trivial") + ) print(f"Action: shutdown engine") engine.shutdown() @@ -103,7 +104,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = engine.flush_cache() assert ret.success From 5ed6e300ff9886b6091d30052b80ab30c7b40ec0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:55:06 +0800 Subject: [PATCH 0540/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 5 +++-- python/sglang/srt/managers/expert_location.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index 8e785884119..a5fb6e1af89 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Any, Dict, Optional -from sglang.srt.managers.tokenizer_manager import TokenizerManager +if TYPE_CHECKING: + from sglang.srt.managers.tokenizer_manager import TokenizerManager logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ def __init__(self, dir_data): if not self._dir_data.exists(): self._dir_data.mkdir(parents=True, exist_ok=True) - def bind(self, tokenizer_manager: TokenizerManager): + def bind(self, tokenizer_manager: "TokenizerManager"): self._tokenizer_manager = tokenizer_manager async def start(self): diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 8502931625b..da07e07afd6 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -8,7 +8,6 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader import get_model_architecture from sglang.srt.server_args import ServerArgs From 2c4be1f395dde8c0cedfdf70a1884015d968453d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:55:42 +0800 Subject: [PATCH 0541/1089] more --- python/sglang/srt/managers/expert_distribution_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index a5fb6e1af89..739a58062d2 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -2,7 +2,7 @@ import logging import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, TYPE_CHECKING if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager From 86e354fa242bfd60da7d92b7579ede11cd29feab Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:57:36 +0800 Subject: [PATCH 0542/1089] more --- python/sglang/srt/layers/moe/topk.py | 6 +----- python/sglang/srt/managers/expert_distribution_storage.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b41fc5430bc..0fd70f88c02 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -18,11 +18,7 @@ import torch.nn.functional as F from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.schedule_batch import ( - get_global_expert_location_metadata, - global_expert_location_metadata, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index 739a58062d2..ebc51daec7f 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -2,7 +2,7 @@ import logging import time from pathlib import Path -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager From 9feccffff154ac3b748d10439f5e20ccbb4a91f8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 09:58:32 +0800 Subject: [PATCH 0543/1089] more --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e0fde56d9b1..fe7722b2db9 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1626,7 +1626,7 @@ def get_model_config_for_expert_location(cls, config): return ModelConfigForExpertLocation( num_layers=config.num_hidden_layers, num_logical_experts=config.n_routed_experts, - num_groups=config.n_groups, + num_groups=config.n_group, ) From 36cba1b06a9aa8d44221d83eb9d8460c0efb5c19 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:07:34 +0800 Subject: [PATCH 0544/1089] more --- test/srt/test_eplb.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index df6557e0b50..e8501b5400b 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,9 +3,7 @@ from typing import List import sglang as sgl -from python.sglang.srt.managers.expert_distribution_storage import ( - ExpertDistributionStorage, -) +from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( 
DEFAULT_MLA_MODEL_NAME_FOR_TEST, From e9040f6e84399e998a2bb2102787a27f8d8e9a1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:08:51 +0800 Subject: [PATCH 0545/1089] more --- test/srt/test_eplb.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e8501b5400b..e2ae57a826f 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -26,6 +26,9 @@ def test_eplb_e2e(self): enable_eplb=True, eplb_storage_dir=tmpdir, ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + enable_deepep_moe=True, + deepep_moe_mode="normal", + disable_cuda_graph=True, tp_size=2, ) @@ -62,6 +65,9 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, enable_eplb=True, ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + enable_deepep_moe=True, + deepep_moe_mode="normal", + disable_cuda_graph=True, tp_size=2, ) From 3e5c230614ae957afe4a1c2195754140586ba85d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:10:52 +0800 Subject: [PATCH 0546/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e2ae57a826f..a1758eaf0aa 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -27,7 +27,7 @@ def test_eplb_e2e(self): eplb_storage_dir=tmpdir, ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, enable_deepep_moe=True, - deepep_moe_mode="normal", + deepep_mode="normal", disable_cuda_graph=True, tp_size=2, ) @@ -66,7 +66,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): enable_eplb=True, ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, enable_deepep_moe=True, - deepep_moe_mode="normal", + deepep_mode="normal", disable_cuda_graph=True, tp_size=2, ) From 2b9f33133560b44a043a930df1f924717bb44290 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:12:35 +0800 Subject: [PATCH 0547/1089] more --- python/sglang/srt/managers/eplb_manager.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 8aa7deefec0..b9336e45821 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -55,11 +55,7 @@ def save_expert_distribution(self): self._expert_distribution_storage.save_current() def compute_expert_location_metadata(self): - logical_count = self._expert_distribution_storage.get_last_snapshot()[ - "logical_count" - ] - if logical_count is None: + snapshot = self._expert_distribution_storage.get_last_snapshot() + if snapshot is None: return ExpertLocationMetadata.init_trivial(self._server_args) - return ExpertLocationMetadata.init_by_eplb( - self._server_args, logical_count=logical_count - ) + return ExpertLocationMetadata.init_by_eplb(self._server_args, **snapshot) From 42c569d79ef33d075e2dd679b9fc1707d3fc9007 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:13:05 +0800 Subject: [PATCH 0548/1089] more --- test/srt/test_eplb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index a1758eaf0aa..413e32f3c8e 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,6 +23,7 @@ def test_eplb_e2e(self): with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, enable_eplb=True, eplb_storage_dir=tmpdir, 
ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, @@ -63,6 +64,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, enable_eplb=True, ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, enable_deepep_moe=True, From 4d4b0d6bb24678d3152016d37edb600314026dc6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:16:22 +0800 Subject: [PATCH 0549/1089] more --- python/sglang/srt/models/deepseek_v2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fe7722b2db9..ea17909145b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -254,7 +254,10 @@ def __init__( if global_server_args_dict["enable_deepep_moe"]: # TODO: we will support tp < ep in the future self.ep_size = get_tensor_model_parallel_world_size() - self.num_experts = config.n_routed_experts + self.num_experts = ( + config.n_routed_experts + + global_server_args_dict["ep_num_redundant_experts"] + ) self.top_k = config.num_experts_per_tok self.renormalize = config.norm_topk_prob self.topk_group = config.topk_group @@ -269,7 +272,8 @@ def __init__( group=parallel_state.get_tp_group().device_group, router_topk=self.top_k, permute_fusion=True, - num_experts=config.n_routed_experts, + num_experts=config.n_routed_experts + + global_server_args_dict["ep_num_redundant_experts"], num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, From c9b11cf2eefb687680b1c6466739d3612239cac5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:25:50 +0800 Subject: [PATCH 0550/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 64b3cc5db65..19099a2d866 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -646,7 +646,9 @@ async def _handle_batch_request( pass async def flush_cache(self) -> FlushCacheReqOutput: - return await self.flush_cache_communicator(FlushCacheReqInput()) + outputs = await self.flush_cache_communicator(FlushCacheReqInput()) + success = all(output.success for output in outputs) + return FlushCacheReqOutput(success=success) def abort_request(self, rid: str): if rid not in self.rid_to_state: From ccd7d38eba0f486865dc450b619d30e466b3450c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:27:14 +0800 Subject: [PATCH 0551/1089] more --- python/sglang/srt/managers/io_struct.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1f5ca625904..8c67b9f189c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -677,6 +677,7 @@ class FlushCacheReqOutput: success: bool +@dataclass class UpdateExpertLocationReqInput: expert_location_metadata: "ExpertLocationMetadata" From dcd16c02eb8e0a5c13108cf8cada650492e6a8ad Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:29:47 +0800 Subject: [PATCH 0552/1089] more --- test/srt/test_eplb.py | 109 +++++++++++++++++++++--------------------- 1 
file changed, 55 insertions(+), 54 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 413e32f3c8e..794876bdf97 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -16,6 +16,9 @@ _NUM_ROUTED_EXPERTS = 64 # DeepSeek-Coder-V2-Lite-Instruct _EP_NUM_REDUNDANT_EXPERTS = 4 _NUM_OVERALL_PHYSICAL_EXPERTS = _NUM_ROUTED_EXPERTS + _EP_NUM_REDUNDANT_EXPERTS +_TRIVIAL_EXPERT_LOCATIONS = list( + x % _NUM_ROUTED_EXPERTS for x in range(_NUM_OVERALL_PHYSICAL_EXPERTS) +) class TestEPLB(CustomTestCase): @@ -30,6 +33,7 @@ def test_eplb_e2e(self): enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, + enable_scheduler_input_blocker=True, tp_size=2, ) @@ -60,54 +64,55 @@ def test_eplb_e2e(self): engine.shutdown() del engine - def test_eplb_init_expert_location_and_save_expert_distribution(self): - with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: - engine_kwargs = dict( - model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, - trust_remote_code=True, - enable_eplb=True, - ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, - enable_deepep_moe=True, - deepep_mode="normal", - disable_cuda_graph=True, - tp_size=2, - ) - - print(f"Action: start engine") - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) - ref_output = self._engine_generate(engine) - self._assert_behavior(engine, ref_output, "equal_trivial") - - print(f"Action: eplb_save_expert_distribution") - engine.eplb_save_expert_distribution() - snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( - eplb_storage_dir_a - ) - assert snapshot_path is not None - - print(f"Action: shutdown engine") - engine.shutdown() - del engine - - print(f"Action: start engine with init_expert_location") - engine = sgl.Engine( - **engine_kwargs, - eplb_storage_dir=eplb_storage_dir_b, - init_expert_location=str(snapshot_path), - ) - self._assert_behavior(engine, ref_output, "not_equal_trivial") - print(f"Action: shutdown engine") - engine.shutdown() - del engine - - print( - f"Action: start engine to check automatically loading from storage dir" - ) - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) - self._assert_behavior(engine, ref_output, "not_equal_trivial") - print(f"Action: shutdown engine") - engine.shutdown() - del engine + # def test_eplb_init_expert_location_and_save_expert_distribution(self): + # with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: + # engine_kwargs = dict( + # model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + # trust_remote_code=True, + # enable_eplb=True, + # ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + # enable_deepep_moe=True, + # deepep_mode="normal", + # disable_cuda_graph=True, + # enable_scheduler_input_blocker=True, + # tp_size=2, + # ) + # + # print(f"Action: start engine") + # engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + # ref_output = self._engine_generate(engine) + # self._assert_behavior(engine, ref_output, "equal_trivial") + # + # print(f"Action: eplb_save_expert_distribution") + # engine.eplb_save_expert_distribution() + # snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( + # eplb_storage_dir_a + # ) + # assert snapshot_path is not None + # + # print(f"Action: shutdown engine") + # engine.shutdown() + # del engine + # + # print(f"Action: start engine with init_expert_location") + # engine = sgl.Engine( + # **engine_kwargs, + # 
eplb_storage_dir=eplb_storage_dir_b, + # init_expert_location=str(snapshot_path), + # ) + # self._assert_behavior(engine, ref_output, "not_equal_trivial") + # print(f"Action: shutdown engine") + # engine.shutdown() + # del engine + # + # print( + # f"Action: start engine to check automatically loading from storage dir" + # ) + # engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + # self._assert_behavior(engine, ref_output, "not_equal_trivial") + # print(f"Action: shutdown engine") + # engine.shutdown() + # del engine def _assert_behavior( self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map @@ -125,14 +130,10 @@ def _assert_behavior( print(f"{physical_to_logical_map_layer_0=}") if expect_physical_to_local_map == "equal_trivial": - self.assertEqual( - physical_to_logical_map_layer_0, - list(range(_NUM_OVERALL_PHYSICAL_EXPERTS)), - ) + self.assertEqual(physical_to_logical_map_layer_0, _TRIVIAL_EXPERT_LOCATIONS) elif expect_physical_to_local_map == "not_equal_trivial": self.assertNotEqual( - physical_to_logical_map_layer_0, - list(range(_NUM_OVERALL_PHYSICAL_EXPERTS)), + physical_to_logical_map_layer_0, _TRIVIAL_EXPERT_LOCATIONS ) else: self.assertEqual( From dcc6ef4987e70d3385fe9940a413a34339594326 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:31:12 +0800 Subject: [PATCH 0553/1089] more --- test/srt/test_eplb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 794876bdf97..0026b998818 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -34,6 +34,7 @@ def test_eplb_e2e(self): deepep_mode="normal", disable_cuda_graph=True, enable_scheduler_input_blocker=True, + disable_overlap_schedule=True, # TODO tp_size=2, ) @@ -75,6 +76,7 @@ def test_eplb_e2e(self): # deepep_mode="normal", # disable_cuda_graph=True, # enable_scheduler_input_blocker=True, + # disable_overlap_schedule=True, # TODO # tp_size=2, # ) # From 3fac8d9ce8fea649165e25bad67b1ba8d4fa2f16 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:32:06 +0800 Subject: [PATCH 0554/1089] more --- test/srt/test_eplb.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0026b998818..79942e4a94a 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -75,8 +75,6 @@ def test_eplb_e2e(self): # enable_deepep_moe=True, # deepep_mode="normal", # disable_cuda_graph=True, - # enable_scheduler_input_blocker=True, - # disable_overlap_schedule=True, # TODO # tp_size=2, # ) # From 7abbe6c3fd5776f0d282570ea919d6c0b7f6eef5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:35:48 +0800 Subject: [PATCH 0555/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b9336e45821..ee544801e8e 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -40,10 +40,10 @@ async def handle_loop(self): f"Sleep {sleep_time} seconds before automatically trigger rebalancing" ) await asyncio.sleep(sleep_time) - self.save_expert_distribution() await self.rebalance() async def rebalance(self): + self.save_expert_distribution() expert_location_metadata = self.compute_expert_location_metadata() await self._tokenizer_manager.update_expert_location( UpdateExpertLocationReqInput( From 566ed9afb36844efd8dc2e6b78376fe203581a57 Mon Sep 17 00:00:00 2001 From: 
fzyzcjy Date: Fri, 11 Apr 2025 10:39:26 +0800 Subject: [PATCH 0556/1089] more --- python/sglang/srt/managers/eplb_manager.py | 6 +++--- python/sglang/srt/managers/tokenizer_manager.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index ee544801e8e..84ab35f972a 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -43,7 +43,7 @@ async def handle_loop(self): await self.rebalance() async def rebalance(self): - self.save_expert_distribution() + await self.save_expert_distribution() expert_location_metadata = self.compute_expert_location_metadata() await self._tokenizer_manager.update_expert_location( UpdateExpertLocationReqInput( @@ -51,8 +51,8 @@ async def rebalance(self): ) ) - def save_expert_distribution(self): - self._expert_distribution_storage.save_current() + async def save_expert_distribution(self): + await self._expert_distribution_storage.save_current() def compute_expert_location_metadata(self): snapshot = self._expert_distribution_storage.get_last_snapshot() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 19099a2d866..d48c0810674 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,7 +45,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -701,7 +700,7 @@ async def eplb_rebalance(self): async def eplb_save_expert_distribution(self): self.auto_create_handle_loop() - self.eplb_manager.save_expert_distribution() + await self.eplb_manager.save_expert_distribution() async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() @@ -1026,8 +1025,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From f54b8b4da5070b11433463e2c25e186cf73d7c7d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:43:43 +0800 Subject: [PATCH 0557/1089] more --- python/sglang/srt/managers/expert_location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index da07e07afd6..e3c7b3d953c 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -72,6 +72,8 @@ def init_by_mapping( @staticmethod def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): + if not isinstance(logical_count, torch.Tensor): + logical_count = torch.tensor(logical_count) common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] From b53e20ecaa85db1fdcd7e466f8d2182f9c3ab574 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:43:51 +0800 Subject: [PATCH 0558/1089] fmt --- python/sglang/srt/managers/tokenizer_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index d48c0810674..2d08fdb8ebb 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -1025,8 +1026,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From 5b73b748206b4e676147bdd2ffbf7a1bd2a61379 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:43:54 +0800 Subject: [PATCH 0559/1089] more --- test/srt/test_eplb.py | 130 +++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 79942e4a94a..4d36230dd7e 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -22,98 +22,98 @@ class TestEPLB(CustomTestCase): - def test_eplb_e2e(self): - with tempfile.TemporaryDirectory() as tmpdir: - engine_kwargs = dict( - model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, - trust_remote_code=True, - enable_eplb=True, - eplb_storage_dir=tmpdir, - ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, - enable_deepep_moe=True, - deepep_mode="normal", - disable_cuda_graph=True, - enable_scheduler_input_blocker=True, - disable_overlap_schedule=True, # TODO - tp_size=2, - ) - - print(f"Action: start engine") - engine = sgl.Engine(**engine_kwargs) - ref_output = self._engine_generate(engine) - self._assert_behavior(engine, ref_output, "equal_trivial") - - print(f"Action: eplb_rebalance") - engine.eplb_rebalance() - physical_to_logical_map_layer_0_after_first_rebalance = ( - self._assert_behavior(engine, ref_output, "not_equal_trivial") - ) - - print(f"Action: shutdown engine") - engine.shutdown() - del engine - - print(f"Action: start engine") - engine = sgl.Engine(**engine_kwargs) - self._assert_behavior( - engine, - ref_output, - physical_to_logical_map_layer_0_after_first_rebalance, - ) - - print(f"Action: shutdown engine") - engine.shutdown() - del engine - - # def test_eplb_init_expert_location_and_save_expert_distribution(self): - # with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: + # def test_eplb_e2e(self): + # with tempfile.TemporaryDirectory() as tmpdir: # engine_kwargs = dict( # model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, # trust_remote_code=True, # enable_eplb=True, + # eplb_storage_dir=tmpdir, # ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, # enable_deepep_moe=True, # deepep_mode="normal", # disable_cuda_graph=True, + # enable_scheduler_input_blocker=True, + # disable_overlap_schedule=True, # TODO # tp_size=2, # ) # # print(f"Action: start engine") - # engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + # engine = sgl.Engine(**engine_kwargs) # ref_output = self._engine_generate(engine) # self._assert_behavior(engine, ref_output, "equal_trivial") # - # print(f"Action: eplb_save_expert_distribution") - # engine.eplb_save_expert_distribution() - # snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( - # eplb_storage_dir_a + # print(f"Action: 
eplb_rebalance") + # engine.eplb_rebalance() + # physical_to_logical_map_layer_0_after_first_rebalance = ( + # self._assert_behavior(engine, ref_output, "not_equal_trivial") # ) - # assert snapshot_path is not None # # print(f"Action: shutdown engine") # engine.shutdown() # del engine # - # print(f"Action: start engine with init_expert_location") - # engine = sgl.Engine( - # **engine_kwargs, - # eplb_storage_dir=eplb_storage_dir_b, - # init_expert_location=str(snapshot_path), + # print(f"Action: start engine") + # engine = sgl.Engine(**engine_kwargs) + # self._assert_behavior( + # engine, + # ref_output, + # physical_to_logical_map_layer_0_after_first_rebalance, # ) - # self._assert_behavior(engine, ref_output, "not_equal_trivial") - # print(f"Action: shutdown engine") - # engine.shutdown() - # del engine # - # print( - # f"Action: start engine to check automatically loading from storage dir" - # ) - # engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) - # self._assert_behavior(engine, ref_output, "not_equal_trivial") # print(f"Action: shutdown engine") # engine.shutdown() # del engine + def test_eplb_init_expert_location_and_save_expert_distribution(self): + with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, + enable_eplb=True, + ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + enable_deepep_moe=True, + deepep_mode="normal", + disable_cuda_graph=True, + tp_size=2, + ) + + print(f"Action: start engine") + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + ref_output = self._engine_generate(engine) + self._assert_behavior(engine, ref_output, "equal_trivial") + + print(f"Action: eplb_save_expert_distribution") + engine.eplb_save_expert_distribution() + snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( + eplb_storage_dir_a + ) + assert snapshot_path is not None + + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print(f"Action: start engine with init_expert_location") + engine = sgl.Engine( + **engine_kwargs, + eplb_storage_dir=eplb_storage_dir_b, + init_expert_location=str(snapshot_path), + ) + self._assert_behavior(engine, ref_output, "not_equal_trivial") + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print( + f"Action: start engine to check automatically loading from storage dir" + ) + engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + self._assert_behavior(engine, ref_output, "not_equal_trivial") + print(f"Action: shutdown engine") + engine.shutdown() + del engine + def _assert_behavior( self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): From f44b37b3d934755b4dd1091676c8074cc1b7b8a3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:46:25 +0800 Subject: [PATCH 0560/1089] more --- test/srt/test_eplb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 4d36230dd7e..9b02c0eab8a 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -36,6 +36,7 @@ class TestEPLB(CustomTestCase): # enable_scheduler_input_blocker=True, # disable_overlap_schedule=True, # TODO # tp_size=2, + # log_level="info", # ) # # print(f"Action: start engine") @@ -76,6 +77,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): deepep_mode="normal", disable_cuda_graph=True, tp_size=2, + log_level="info", 
) print(f"Action: start engine") From 23d571907939943709aa80cf89464b2815b91af2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:48:42 +0800 Subject: [PATCH 0561/1089] more --- python/sglang/srt/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 481cfadf644..0711625f68a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1790,9 +1790,13 @@ def retry( return fn() except Exception as e: if try_index >= max_retry: + logger.warning(f"retry() observe error: {e}") + traceback.print_exc() raise Exception(f"retry() exceed maximum number of retries.") if not should_retry(e): + logger.warning(f"retry() observe error: {e}") + traceback.print_exc() raise Exception(f"retry() observe errors that should not be retried.") delay = min(initial_delay * (2**try_index), max_delay) * ( From 8ef9916a1093bf6fce07056eb225de6050ad2d18 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:53:01 +0800 Subject: [PATCH 0562/1089] more --- test/srt/test_eplb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 9b02c0eab8a..c7775b5887d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,5 +1,6 @@ import tempfile import unittest +from pathlib import Path from typing import List import sglang as sgl @@ -88,7 +89,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): print(f"Action: eplb_save_expert_distribution") engine.eplb_save_expert_distribution() snapshot_path = ExpertDistributionStorage.get_last_snapshot_path( - eplb_storage_dir_a + Path(eplb_storage_dir_a) / "expert_distribution_storage" ) assert snapshot_path is not None From 1bdc7c2370347b8dc78c680be6634f90202abf4c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:54:57 +0800 Subject: [PATCH 0563/1089] more --- python/sglang/srt/managers/eplb_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 84ab35f972a..36f842e0b67 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -37,7 +37,7 @@ async def handle_loop(self): while True: sleep_time = self._server_args.eplb_rebalance_period or 1000000000 logger.info( - f"Sleep {sleep_time} seconds before automatically trigger rebalancing" + f"EPLBManager: Sleep {sleep_time} seconds before next automatic rebalancing" ) await asyncio.sleep(sleep_time) await self.rebalance() From 82f7e5a3116108866fe00b3eba244de1b13843a2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:55:20 +0800 Subject: [PATCH 0564/1089] more --- test/srt/test_eplb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index c7775b5887d..aaf519b446b 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -92,6 +92,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): Path(eplb_storage_dir_a) / "expert_distribution_storage" ) assert snapshot_path is not None + print(f"{snapshot_path.read_text()=}") print(f"Action: shutdown engine") engine.shutdown() @@ -118,7 +119,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = 
engine.flush_cache() assert ret.success From a1ecc17439eda4f63dd520b808924f3fe9fc0ae2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:55:45 +0800 Subject: [PATCH 0565/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index aaf519b446b..6d603eddf82 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -92,7 +92,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): Path(eplb_storage_dir_a) / "expert_distribution_storage" ) assert snapshot_path is not None - print(f"{snapshot_path.read_text()=}") + print(f"{snapshot_path=} {snapshot_path.read_text()=}") print(f"Action: shutdown engine") engine.shutdown() From 845914507bee86d4fdb3f99ebc4910e39bb706bf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:58:10 +0800 Subject: [PATCH 0566/1089] more --- python/sglang/srt/managers/expert_location.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e3c7b3d953c..d9cf94a1dd5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,7 +5,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -34,8 +33,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -50,7 +49,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -107,8 +106,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -156,7 +155,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -164,7 +163,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -184,8 +183,8 @@ def debug_str(self): def _compute_logical_to_rank_dispatch_physical_map( - 
logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity @@ -194,7 +193,8 @@ def _compute_logical_to_rank_dispatch_physical_map( num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape logical_to_rank_dispatch_physical_map = torch.zeros( - (num_gpus, num_layers, num_logical_experts) + (num_gpus, num_layers, num_logical_experts), + dtype=logical_to_all_physical_map.dtype, ) for layer_id in range(num_layers): From dc4055a97b3df344afcbb501e7dffac9d9b9bc98 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:59:12 +0800 Subject: [PATCH 0567/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6c945b884f6..6476019d992 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -26,12 +26,12 @@ import os import signal import threading +from pathlib import Path from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -649,7 +649,7 @@ def _compute_initial_expert_location_metadata( server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: - data_dict = json.loads(data) + data_dict = json.loads(Path(data).read_text()) if "physical_to_logical_map" in data_dict: # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) From be22f63498f1e55c6b10d1d179164c6fc56de236 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 10:59:55 +0800 Subject: [PATCH 0568/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 1 + python/sglang/srt/managers/expert_location.py | 19 ++++++++++--------- test/srt/test_eplb.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6476019d992..a3934b33512 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -32,6 +32,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d9cf94a1dd5..f478d5886bc 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -33,8 +34,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, 
num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) # Highly inefficient, but we do not care since we will use EPLB distribution logic logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( @@ -49,7 +50,7 @@ def init_trivial(server_args: ServerArgs): @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) @@ -106,8 +107,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -155,7 +156,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -163,7 +164,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -183,8 +184,8 @@ def debug_str(self): def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 6d603eddf82..d664ad46c5b 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -119,7 +119,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): ret = engine.flush_cache() assert ret.success From 28fe2934405e8c5602339f91272f8a0fed4e5f66 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:14:20 +0800 Subject: [PATCH 0569/1089] more --- test/srt/test_eplb.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index d664ad46c5b..ac7e9caaa82 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -47,6 +47,7 @@ class TestEPLB(CustomTestCase): # # print(f"Action: eplb_rebalance") # engine.eplb_rebalance() + # self._engine_flush_cache(engine) # physical_to_logical_map_layer_0_after_first_rebalance = ( # self._assert_behavior(engine, ref_output, "not_equal_trivial") # ) @@ -119,11 +120,8 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): - ret = engine.flush_cache() - assert ret.success - actual_output = self._engine_generate(engine) self.assertEqual(actual_output, ref_output) @@ -154,6 +152,10 @@ def _engine_generate(self, engine: sgl.Engine): print(f"engine_generate {output=}") return [x["text"] for x in output] + def _engine_flush_cache(self, engine: sgl.Engine): + ret = engine.flush_cache() + assert ret.success + if __name__ == "__main__": unittest.main() From a6e6cbbdebbe59d53ce89ea5761c52cc2b5b03bb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:25:52 +0800 Subject: [PATCH 0570/1089] more --- test/srt/test_eplb.py | 88 +++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index ac7e9caaa82..dc4d7d12f92 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,50 +23,50 @@ class TestEPLB(CustomTestCase): - # def test_eplb_e2e(self): - # with tempfile.TemporaryDirectory() as tmpdir: - # engine_kwargs = dict( - # model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, - # trust_remote_code=True, - # enable_eplb=True, - # eplb_storage_dir=tmpdir, - # ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, - # enable_deepep_moe=True, - # deepep_mode="normal", - # disable_cuda_graph=True, - # enable_scheduler_input_blocker=True, - # disable_overlap_schedule=True, # TODO - # tp_size=2, - # log_level="info", - # ) - # - # print(f"Action: start engine") - # engine = sgl.Engine(**engine_kwargs) - # ref_output = self._engine_generate(engine) - # self._assert_behavior(engine, ref_output, "equal_trivial") - # - # print(f"Action: eplb_rebalance") - # engine.eplb_rebalance() - # self._engine_flush_cache(engine) - # physical_to_logical_map_layer_0_after_first_rebalance = ( - # self._assert_behavior(engine, ref_output, "not_equal_trivial") - # ) - # - # print(f"Action: shutdown engine") - # engine.shutdown() - # del engine - # - # print(f"Action: start engine") - # engine = sgl.Engine(**engine_kwargs) - # 
self._assert_behavior( - # engine, - # ref_output, - # physical_to_logical_map_layer_0_after_first_rebalance, - # ) - # - # print(f"Action: shutdown engine") - # engine.shutdown() - # del engine + def test_eplb_e2e(self): + with tempfile.TemporaryDirectory() as tmpdir: + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, + enable_eplb=True, + eplb_storage_dir=tmpdir, + ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + enable_deepep_moe=True, + deepep_mode="normal", + disable_cuda_graph=True, + enable_scheduler_input_blocker=True, + disable_overlap_schedule=True, # TODO + tp_size=2, + log_level="info", + ) + + print(f"Action: start engine") + engine = sgl.Engine(**engine_kwargs) + ref_output = self._engine_generate(engine) + self._assert_behavior(engine, ref_output, "equal_trivial") + + print(f"Action: eplb_rebalance") + engine.eplb_rebalance() + self._engine_flush_cache(engine) + physical_to_logical_map_layer_0_after_first_rebalance = ( + self._assert_behavior(engine, ref_output, "not_equal_trivial") + ) + + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print(f"Action: start engine") + engine = sgl.Engine(**engine_kwargs) + self._assert_behavior( + engine, + ref_output, + physical_to_logical_map_layer_0_after_first_rebalance, + ) + + print(f"Action: shutdown engine") + engine.shutdown() + del engine def test_eplb_init_expert_location_and_save_expert_distribution(self): with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: From bfef08df1a5e814f7da13f8101df638fe718dc19 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:26:51 +0800 Subject: [PATCH 0571/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index dc4d7d12f92..1b8d7da06f4 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,7 +23,7 @@ class TestEPLB(CustomTestCase): - def test_eplb_e2e(self): + def _tempdisable_test_eplb_e2e(self): with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -68,7 +68,7 @@ def test_eplb_e2e(self): engine.shutdown() del engine - def test_eplb_init_expert_location_and_save_expert_distribution(self): + def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(self): with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, From 64e7df417881da81788481a0e43ae0caf13fd480 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:27:59 +0800 Subject: [PATCH 0572/1089] more --- python/sglang/srt/entrypoints/engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a3934b33512..a7cd13d7142 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -26,13 +26,13 @@ import os import signal import threading +from json import JSONDecodeError from pathlib import Path from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -650,7 +650,11 @@ 
def _compute_initial_expert_location_metadata( server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: - data_dict = json.loads(Path(data).read_text()) + try: + data_dict = json.loads(data) + except JSONDecodeError: + data_dict = json.loads(Path(data).read_text()) + if "physical_to_logical_map" in data_dict: # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) From 213387eaf089be270ace7d934b52f86317cf5d80 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:30:28 +0800 Subject: [PATCH 0573/1089] more --- python/sglang/srt/managers/expert_location.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index f478d5886bc..0de9cb8c700 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,7 +5,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -34,29 +33,25 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) - # Highly inefficient, but we do not care since we will use EPLB distribution logic - logical_to_all_physical_map = torch.arange(0, num_logical_experts).repeat( - num_layers, 1 - )[..., None] return ExpertLocationMetadata.init_by_mapping( server_args, physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_all_physical_map, ) @staticmethod def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map, logical_to_all_physical_map + server_args: ServerArgs, physical_to_logical_map ): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] + logical_to_all_physical_map = _compute_logical_to_all_physical_map(physical_to_logical_map) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, @@ -107,8 +102,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = server_args.tp_size @@ -156,7 +151,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -164,7 +159,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, 
layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -183,9 +178,13 @@ def debug_str(self): ) +def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor): + return TODO + + def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity From 247fd9d02d46b3bf45539c740fe8296282fb257f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:32:14 +0800 Subject: [PATCH 0574/1089] more --- python/sglang/srt/managers/expert_location.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 0de9cb8c700..3a7e73d83b0 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -51,7 +51,9 @@ def init_by_mapping( common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] - logical_to_all_physical_map = _compute_logical_to_all_physical_map(physical_to_logical_map) + logical_to_all_physical_map = _compute_logical_to_all_physical_map(physical_to_logical_map, + num_logical_experts=model_config_for_expert_location.num_logical_experts + ) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, @@ -178,7 +180,21 @@ def debug_str(self): ) -def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor): +def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, num_logical_experts: int): + num_layers, num_physical_experts = physical_to_logical_map.shape + + logical_to_all_physical_map = [ + [ + [] + for physical_expert_id in range(num_logical_experts) + ] + for layer_id in range(num_layers) + ] + + for layer_id in range(num_layers): + for physical_expert_id in range(num_physical_experts): + TODO + return TODO From 8fbd02678917f4cc19b41c9b7b396419681b6612 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:33:08 +0800 Subject: [PATCH 0575/1089] more --- python/sglang/srt/managers/expert_location.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3a7e73d83b0..e38dd8cc0ca 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -186,14 +186,15 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, logical_to_all_physical_map = [ [ [] - for physical_expert_id in range(num_logical_experts) + for _logical_expert_id in range(num_logical_experts) ] - for layer_id in range(num_layers) + for _layer_id in range(num_layers) ] for layer_id in range(num_layers): for physical_expert_id in range(num_physical_experts): - TODO + logical_expert_id = physical_to_logical_map[layer_id, physical_expert_id].item() + logical_to_all_physical_map[layer_id][logical_expert_id].append(physical_expert_id) return TODO From a063852fb59efc821c1fd4494991e5dee2755626 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:35:01 +0800 Subject: [PATCH 0576/1089] more --- python/sglang/srt/managers/expert_location.py | 
13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e38dd8cc0ca..27d14767ef7 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -196,7 +196,18 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, logical_expert_id = physical_to_logical_map[layer_id, physical_expert_id].item() logical_to_all_physical_map[layer_id][logical_expert_id].append(physical_expert_id) - return TODO + max_logical_per_physical = max( + len(logical_to_all_physical_map[layer_id][logical_expert_id]) + for layer_id in range(num_layers) + for logical_expert_id in range(num_logical_experts) + ) + + for layer_id in range(num_layers): + for logical_expert_id in range(num_logical_experts): + target = logical_to_all_physical_map[layer_id][logical_expert_id] + target += [-1] * (max_logical_per_physical - len(target)) + + return torch.tensor(logical_to_all_physical_map) def _compute_logical_to_rank_dispatch_physical_map( From a3509fccd11eed4c692a659137820acc2e1ed33e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:36:23 +0800 Subject: [PATCH 0577/1089] more --- python/sglang/srt/managers/expert_location.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 27d14767ef7..c9803932a82 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -181,15 +181,12 @@ def debug_str(self): def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, num_logical_experts: int): + # This is rarely called, so we use for loops for maximum clarity + num_layers, num_physical_experts = physical_to_logical_map.shape - logical_to_all_physical_map = [ - [ - [] - for _logical_expert_id in range(num_logical_experts) - ] - for _layer_id in range(num_layers) - ] + logical_to_all_physical_map = [[[] for _logical_expert_id in range(num_logical_experts)] for _layer_id in + range(num_layers)] for layer_id in range(num_layers): for physical_expert_id in range(num_physical_experts): From 00748f23711282139604ef08bfc6f27c81e37e74 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:39:13 +0800 Subject: [PATCH 0578/1089] more --- python/sglang/srt/managers/expert_location.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c9803932a82..7b09c71b93f 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -193,6 +193,12 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, logical_expert_id = physical_to_logical_map[layer_id, physical_expert_id].item() logical_to_all_physical_map[layer_id][logical_expert_id].append(physical_expert_id) + logical_to_all_physical_map = _pad_nested_array(logical_to_all_physical_map, pad_value=-1) + + return torch.tensor(logical_to_all_physical_map) + + +def _pad_nested_array(arr, pad_value): max_logical_per_physical = max( len(logical_to_all_physical_map[layer_id][logical_expert_id]) for layer_id in range(num_layers) @@ -204,7 +210,7 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, target = logical_to_all_physical_map[layer_id][logical_expert_id] target += 
[-1] * (max_logical_per_physical - len(target)) - return torch.tensor(logical_to_all_physical_map) + return TODO def _compute_logical_to_rank_dispatch_physical_map( From ee99458736abe00bc43d3aca104f5ed5cc7a77b4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:41:33 +0800 Subject: [PATCH 0579/1089] more --- python/sglang/srt/managers/expert_location.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7b09c71b93f..d9c35e85bcc 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -199,18 +199,9 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, def _pad_nested_array(arr, pad_value): - max_logical_per_physical = max( - len(logical_to_all_physical_map[layer_id][logical_expert_id]) - for layer_id in range(num_layers) - for logical_expert_id in range(num_logical_experts) - ) - - for layer_id in range(num_layers): - for logical_expert_id in range(num_logical_experts): - target = logical_to_all_physical_map[layer_id][logical_expert_id] - target += [-1] * (max_logical_per_physical - len(target)) - - return TODO + max_len = max(len(inner) for outer in arr for inner in outer) + padded = [[[inner + [pad_value] * (max_len - len(inner))] for inner in outer] for outer in arr] + return torch.tensor(padded) def _compute_logical_to_rank_dispatch_physical_map( From a19dbe2761500c8006b4a0205ffa14dcc2f0d040 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:41:50 +0800 Subject: [PATCH 0580/1089] more --- python/sglang/srt/managers/expert_location.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d9c35e85bcc..06f5e314c40 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -185,9 +185,7 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, num_layers, num_physical_experts = physical_to_logical_map.shape - logical_to_all_physical_map = [[[] for _logical_expert_id in range(num_logical_experts)] for _layer_id in - range(num_layers)] - + logical_to_all_physical_map = [[[] for _ in range(num_logical_experts)] for _ in range(num_layers)] for layer_id in range(num_layers): for physical_expert_id in range(num_physical_experts): logical_expert_id = physical_to_logical_map[layer_id, physical_expert_id].item() From 6d991a866cbf65861f93c15d726e2dbf3b6068cb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:43:37 +0800 Subject: [PATCH 0581/1089] more --- test/srt/test_eplb.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 1b8d7da06f4..1ff5c86989d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -15,11 +15,6 @@ ) _NUM_ROUTED_EXPERTS = 64 # DeepSeek-Coder-V2-Lite-Instruct -_EP_NUM_REDUNDANT_EXPERTS = 4 -_NUM_OVERALL_PHYSICAL_EXPERTS = _NUM_ROUTED_EXPERTS + _EP_NUM_REDUNDANT_EXPERTS -_TRIVIAL_EXPERT_LOCATIONS = list( - x % _NUM_ROUTED_EXPERTS for x in range(_NUM_OVERALL_PHYSICAL_EXPERTS) -) class TestEPLB(CustomTestCase): @@ -30,7 +25,7 @@ def _tempdisable_test_eplb_e2e(self): trust_remote_code=True, enable_eplb=True, eplb_storage_dir=tmpdir, - ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + ep_num_redundant_experts=4, 
enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, @@ -74,7 +69,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, enable_eplb=True, - ep_num_redundant_experts=_EP_NUM_REDUNDANT_EXPERTS, + ep_num_redundant_experts=4, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, @@ -131,12 +126,12 @@ def _assert_behavior( physical_to_logical_map_layer_0 = physical_to_logical_map[0, :].tolist() print(f"{physical_to_logical_map_layer_0=}") + trivial_expert_locations = _compute_trivial_expert_locations(engine.server_args.ep_num_redundant_experts) + if expect_physical_to_local_map == "equal_trivial": - self.assertEqual(physical_to_logical_map_layer_0, _TRIVIAL_EXPERT_LOCATIONS) + self.assertEqual(physical_to_logical_map_layer_0, trivial_expert_locations) elif expect_physical_to_local_map == "not_equal_trivial": - self.assertNotEqual( - physical_to_logical_map_layer_0, _TRIVIAL_EXPERT_LOCATIONS - ) + self.assertNotEqual(physical_to_logical_map_layer_0, trivial_expert_locations) else: self.assertEqual( physical_to_logical_map_layer_0, expect_physical_to_local_map @@ -156,6 +151,9 @@ def _engine_flush_cache(self, engine: sgl.Engine): ret = engine.flush_cache() assert ret.success +def _compute_trivial_expert_locations(ep_num_redundant_experts: int): + return list(x % _NUM_ROUTED_EXPERTS for x in range(_NUM_ROUTED_EXPERTS + ep_num_redundant_experts)) if __name__ == "__main__": unittest.main() + From ff6d45a4013a3cbcfa3544a14267a9ef037ca361 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:44:52 +0800 Subject: [PATCH 0582/1089] more --- test/srt/test_eplb.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 1ff5c86989d..742a27a982d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -114,6 +114,23 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel engine.shutdown() del engine + def test_nontrivial_location(self): + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, + ep_num_redundant_experts=4, + enable_deepep_moe=True, + deepep_mode="normal", + disable_cuda_graph=True, + tp_size=2, + log_level="info", + ) + + engine = sgl.Engine(**engine_kwargs) + self._assert_behavior(engine, ref_output, TODO) + engine.shutdown() + del engine + def _assert_behavior( self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map ): From d754a2d1071ad48534cbaa7aa9198883290f694d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:45:47 +0800 Subject: [PATCH 0583/1089] more --- test/srt/test_eplb.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 742a27a982d..34981d32505 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -15,6 +15,7 @@ ) _NUM_ROUTED_EXPERTS = 64 # DeepSeek-Coder-V2-Lite-Instruct +_REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] class TestEPLB(CustomTestCase): @@ -37,14 +38,13 @@ def _tempdisable_test_eplb_e2e(self): print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) - ref_output = self._engine_generate(engine) - self._assert_behavior(engine, ref_output, "equal_trivial") + self._assert_behavior(engine, "equal_trivial") print(f"Action: eplb_rebalance") engine.eplb_rebalance() self._engine_flush_cache(engine) 
physical_to_logical_map_layer_0_after_first_rebalance = ( - self._assert_behavior(engine, ref_output, "not_equal_trivial") + self._assert_behavior(engine, "not_equal_trivial") ) print(f"Action: shutdown engine") @@ -55,7 +55,6 @@ def _tempdisable_test_eplb_e2e(self): engine = sgl.Engine(**engine_kwargs) self._assert_behavior( engine, - ref_output, physical_to_logical_map_layer_0_after_first_rebalance, ) @@ -79,8 +78,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) - ref_output = self._engine_generate(engine) - self._assert_behavior(engine, ref_output, "equal_trivial") + self._assert_behavior(engine, "equal_trivial") print(f"Action: eplb_save_expert_distribution") engine.eplb_save_expert_distribution() @@ -100,7 +98,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel eplb_storage_dir=eplb_storage_dir_b, init_expert_location=str(snapshot_path), ) - self._assert_behavior(engine, ref_output, "not_equal_trivial") + self._assert_behavior(engine, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() del engine @@ -109,7 +107,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel f"Action: start engine to check automatically loading from storage dir" ) engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) - self._assert_behavior(engine, ref_output, "not_equal_trivial") + self._assert_behavior(engine, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() del engine @@ -127,15 +125,15 @@ def test_nontrivial_location(self): ) engine = sgl.Engine(**engine_kwargs) - self._assert_behavior(engine, ref_output, TODO) + self._assert_behavior(engine, TODO) engine.shutdown() del engine def _assert_behavior( - self, engine: sgl.Engine, ref_output: List[str], expect_physical_to_local_map + self, engine: sgl.Engine, expect_physical_to_local_map ): actual_output = self._engine_generate(engine) - self.assertEqual(actual_output, ref_output) + self.assertEqual(actual_output, _REF_OUTPUT) physical_to_logical_map = ( engine.tokenizer_manager.expert_location_metadata.physical_to_logical_map From 4b2c7e94f5294a5cd68f147b9dd2f874b99f0c83 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:46:32 +0800 Subject: [PATCH 0584/1089] more --- test/srt/test_eplb.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 34981d32505..23e133ad1ec 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,7 +1,7 @@ +import json import tempfile import unittest from pathlib import Path -from typing import List import sglang as sgl from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -124,13 +124,17 @@ def test_nontrivial_location(self): log_level="info", ) - engine = sgl.Engine(**engine_kwargs) + init_expert_location = dict( + physical_to_logical_map=TODO, + ) + + engine = sgl.Engine(**engine_kwargs, init_expert_location=json.dumps(init_expert_location)) self._assert_behavior(engine, TODO) engine.shutdown() del engine def _assert_behavior( - self, engine: sgl.Engine, expect_physical_to_local_map + self, engine: sgl.Engine, expect_physical_to_local_map ): actual_output = self._engine_generate(engine) self.assertEqual(actual_output, _REF_OUTPUT) @@ -166,9 +170,10 @@ def _engine_flush_cache(self, engine: sgl.Engine): ret = engine.flush_cache() 
assert ret.success + def _compute_trivial_expert_locations(ep_num_redundant_experts: int): return list(x % _NUM_ROUTED_EXPERTS for x in range(_NUM_ROUTED_EXPERTS + ep_num_redundant_experts)) + if __name__ == "__main__": unittest.main() - From d488cd04981482643710d85c389ae70b9ead2098 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:47:05 +0800 Subject: [PATCH 0585/1089] more --- test/srt/test_eplb.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 23e133ad1ec..d63588488fc 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -4,6 +4,7 @@ from pathlib import Path import sglang as sgl +import torch from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -124,9 +125,11 @@ def test_nontrivial_location(self): log_level="info", ) - init_expert_location = dict( - physical_to_logical_map=TODO, + physical_to_logical_map = ( + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) + init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist()) engine = sgl.Engine(**engine_kwargs, init_expert_location=json.dumps(init_expert_location)) self._assert_behavior(engine, TODO) From 4275aaa85d2951cc790b89685c1c0606be79e69b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:48:31 +0800 Subject: [PATCH 0586/1089] more --- test/srt/test_eplb.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index d63588488fc..5b61ebc3001 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -15,7 +15,9 @@ popen_launch_server, ) -_NUM_ROUTED_EXPERTS = 64 # DeepSeek-Coder-V2-Lite-Instruct +# DeepSeek-Coder-V2-Lite-Instruct +_NUM_ROUTED_EXPERTS = 64 +_NUM_HIDDEN_LAYERS = 27 _REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] @@ -114,10 +116,11 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel del engine def test_nontrivial_location(self): + ep_num_redundant_experts = 4 engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, - ep_num_redundant_experts=4, + ep_num_redundant_experts=ep_num_redundant_experts, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, @@ -126,8 +129,8 @@ def test_nontrivial_location(self): ) physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1) + % _NUM_ROUTED_EXPERTS ) init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist()) From b5076c90be4867b14c7164cf34fd48a0c7677103 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:48:44 +0800 Subject: [PATCH 0587/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 5b61ebc3001..634f8995fe1 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -135,7 +135,7 @@ def test_nontrivial_location(self): init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist()) engine = sgl.Engine(**engine_kwargs, init_expert_location=json.dumps(init_expert_location)) - self._assert_behavior(engine, TODO) + self._assert_behavior(engine, physical_to_logical_map[0]) engine.shutdown() del engine 
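[Editor's note, not part of any patch: across PATCH 0573 through 0590 the expert-location mapping helpers converge to the shape sketched below. This is a minimal standalone consolidation for reference, assuming only the code visible in the diffs above; the toy dimensions (1 layer, 4 logical experts, 2 redundant physical slots) are invented for illustration.]

import torch


def _pad_nested_array(arr, pad_value):
    # Pad every innermost list to the maximum length so the result is rectangular.
    max_len = max(len(inner) for outer in arr for inner in outer)
    return [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]


def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    # Invert physical->logical: one logical expert may own several physical slots.
    num_layers, num_physical_experts = physical_to_logical_map.shape
    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    for layer_id in range(num_layers):
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = physical_to_logical_map[
                layer_id, physical_expert_id
            ].item()
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )
    return torch.tensor(
        _pad_nested_array(logical_to_all_physical_map, pad_value=-1)
    )


# 1 layer, 6 physical slots mapped onto 4 logical experts (2 redundant copies).
p2l = torch.arange(0, 6).repeat(1, 1) % 4  # tensor([[0, 1, 2, 3, 0, 1]])
print(_compute_logical_to_all_physical_map(p2l, num_logical_experts=4))
# tensor([[[ 0,  4], [ 1,  5], [ 2, -1], [ 3, -1]]])

[End of note.]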
From 4af2a0c58f18935bb63ac9f2d0ba1110ae8f9065 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 11:49:24 +0800
Subject: [PATCH 0588/1089] more

---
 test/srt/test_eplb.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py
index 634f8995fe1..a3e01a51368 100755
--- a/test/srt/test_eplb.py
+++ b/test/srt/test_eplb.py
@@ -115,7 +115,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel
         engine.shutdown()
         del engine

-    def test_nontrivial_location(self):
+    def _tempdisable_test_nontrivial_location(self):
         ep_num_redundant_experts = 4
         engine_kwargs = dict(
             model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
@@ -139,6 +139,23 @@ def _tempdisable_test_nontrivial_location(self):
         engine.shutdown()
         del engine

+    def test_trivial(self):
+        engine_kwargs = dict(
+            model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+            trust_remote_code=True,
+            ep_num_redundant_experts=0,
+            enable_deepep_moe=True,
+            deepep_mode="normal",
+            disable_cuda_graph=True,
+            tp_size=2,
+            log_level="info",
+        )
+
+        engine = sgl.Engine(**engine_kwargs)
+        self._assert_behavior(engine, "equal_trivial")
+        engine.shutdown()
+        del engine
+
     def _assert_behavior(
         self, engine: sgl.Engine, expect_physical_to_local_map
     ):

From c4cc082b587614d5ab69c6033c7072ea8713b045 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 11:50:49 +0800
Subject: [PATCH 0589/1089] more

---
 python/sglang/srt/managers/expert_location.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 06f5e314c40..792dd8f6f04 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -199,7 +199,7 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor,
 def _pad_nested_array(arr, pad_value):
     max_len = max(len(inner) for outer in arr for inner in outer)
     padded = [[[inner + [pad_value] * (max_len - len(inner))] for inner in outer] for outer in arr]
-    return torch.tensor(padded)
+    return padded


 def _compute_logical_to_rank_dispatch_physical_map(

From c879ed6063841a10e0bd5e089d36ab5b290f4f00 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 11:52:39 +0800
Subject: [PATCH 0590/1089] more

---
 python/sglang/srt/managers/expert_location.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 792dd8f6f04..71747fec177 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -198,7 +198,7 @@ def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor,

 def _pad_nested_array(arr, pad_value):
     max_len = max(len(inner) for outer in arr for inner in outer)
-    padded = [[[inner + [pad_value] * (max_len - len(inner))] for inner in outer] for outer in arr]
+    padded = [[inner + [pad_value] * (max_len - len(inner)) for inner in outer] for outer in arr]
     return padded
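The two patches above change `_pad_nested_array` in expert_location.py to return a plain nested list and to drop a stray pair of brackets that wrapped every padded row in an extra singleton list. A small sketch of the fixed behavior, assuming the nesting suggested by `_compute_logical_to_all_physical_map` (outer = layers, middle = logical experts, inner = physical replica ids):

    def _pad_nested_array(arr, pad_value):
        # Pad every innermost list to the longest replica list seen anywhere.
        max_len = max(len(inner) for outer in arr for inner in outer)
        return [
            [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
            for outer in arr
        ]

    arr = [[[0], [1, 9]], [[2, 3], [4]]]
    print(_pad_nested_array(arr, pad_value=-1))
    # -> [[[0, -1], [1, 9]], [[2, 3], [4, -1]]]

With the pre-0590 extra brackets the same call would yield [[[[0, -1]], [[1, 9]]], ...], i.e. a fourth nesting level, which no longer stacks into a (layers, experts, max_len) tensor.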
log_level="info", ) + offset = 3 physical_to_logical_map = ( - torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1) + (offset + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1)) % _NUM_ROUTED_EXPERTS ) init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist()) From 6f0bc859d86a8f84616f609b79b602e43b97ccb0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:54:55 +0800 Subject: [PATCH 0592/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0dc7eb442e8..3a6a5a4b1b4 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -140,11 +140,11 @@ def _tempdisable_test_nontrivial_location(self): engine.shutdown() del engine - def test_trivial(self): + def test_trivial_with_redundant_experts(self): engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, - ep_num_redundant_experts=0, + ep_num_redundant_experts=4, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, From 5eb3cfcdd3a489fa52e01f9dc56e0eb24fa5ceae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 11:57:51 +0800 Subject: [PATCH 0593/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 3a6a5a4b1b4..9592544e481 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -115,7 +115,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel engine.shutdown() del engine - def _tempdisable_test_nontrivial_location(self): + def test_nontrivial_location(self): ep_num_redundant_experts = 4 engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -140,7 +140,7 @@ def _tempdisable_test_nontrivial_location(self): engine.shutdown() del engine - def test_trivial_with_redundant_experts(self): + def _tempdisable_test_trivial_with_redundant_experts(self): engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, From 3c55aae652ba47a774a946fb1c374909a9d7cb15 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 16:59:20 +0800 Subject: [PATCH 0594/1089] more --- test/srt/test_eplb.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 9592544e481..e1926016057 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,6 +23,7 @@ class TestEPLB(CustomTestCase): def _tempdisable_test_eplb_e2e(self): + print("Action: test_eplb_e2e") with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -30,6 +31,7 @@ def _tempdisable_test_eplb_e2e(self): enable_eplb=True, eplb_storage_dir=tmpdir, ep_num_redundant_experts=4, + enable_dp_attention=True, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, @@ -66,12 +68,14 @@ def _tempdisable_test_eplb_e2e(self): del engine def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(self): + print("Action: test_eplb_init_expert_location_and_save_expert_distribution") with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, enable_eplb=True, ep_num_redundant_experts=4, + enable_dp_attention=True, enable_deepep_moe=True, 
deepep_mode="normal", disable_cuda_graph=True, @@ -116,11 +120,13 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel del engine def test_nontrivial_location(self): + print("Action: test_nontrivial_location") ep_num_redundant_experts = 4 engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, ep_num_redundant_experts=ep_num_redundant_experts, + enable_dp_attention=True, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, @@ -140,11 +146,13 @@ def test_nontrivial_location(self): engine.shutdown() del engine - def _tempdisable_test_trivial_with_redundant_experts(self): + def test_trivial_with_redundant_experts(self): + print("Action: test_trivial_with_redundant_experts") engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, ep_num_redundant_experts=4, + enable_dp_attention=True, enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, From f4f2f3206b22ba2f5296a03ff05106b7d64d3b3a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:00:32 +0800 Subject: [PATCH 0595/1089] more --- test/srt/test_eplb.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e1926016057..049a70686af 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -38,6 +38,7 @@ def _tempdisable_test_eplb_e2e(self): enable_scheduler_input_blocker=True, disable_overlap_schedule=True, # TODO tp_size=2, + dp_size=2, log_level="info", ) @@ -80,6 +81,7 @@ def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(sel deepep_mode="normal", disable_cuda_graph=True, tp_size=2, + dp_size=2, log_level="info", ) @@ -131,6 +133,7 @@ def test_nontrivial_location(self): deepep_mode="normal", disable_cuda_graph=True, tp_size=2, + dp_size=2, log_level="info", ) @@ -157,6 +160,7 @@ def test_trivial_with_redundant_experts(self): deepep_mode="normal", disable_cuda_graph=True, tp_size=2, + dp_size=2, log_level="info", ) From b03c85b80491c06e779386b56db6136c3f88cbd8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:02:15 +0800 Subject: [PATCH 0596/1089] adhoc --- test/srt/test_eplb.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 049a70686af..4add5b91790 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -18,7 +18,11 @@ # DeepSeek-Coder-V2-Lite-Instruct _NUM_ROUTED_EXPERTS = 64 _NUM_HIDDEN_LAYERS = 27 -_REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] +# TODO +# TODO temp +# TODO +# _REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] +_REF_OUTPUT = [', 4+4=8,'] class TestEPLB(CustomTestCase): @@ -196,7 +200,11 @@ def _assert_behavior( def _engine_generate(self, engine: sgl.Engine): output = engine.generate( - prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + # TODO + # TODO temp + # TODO + prompt=["1+1=2, 2+2=4"], + # prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], sampling_params=dict(max_new_tokens=8, temperature=0.0), ) print(f"engine_generate {output=}") From a1b0422f6f1e41a07b9880a2c31e6d8e1ac26230 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:08:46 +0800 Subject: [PATCH 0597/1089] temp --- python/sglang/srt/models/deepseek_v2.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 
ea17909145b..37ddd2a51eb 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -29,7 +29,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, - tensor_model_parallel_all_reduce, + tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank, ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( @@ -332,6 +332,8 @@ def forward_deepep( num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, ) + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts " + f"{self.layer_id=} {topk_weights=} {topk_idx=} ") if self.ep_size > 1: # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value ( @@ -348,6 +350,8 @@ def forward_deepep( topk_weights, forward_mode=forward_mode, ) + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after dispatch " + f"{self.layer_id=} {topk_weights=} {topk_idx=} {hidden_states[:, :5]=} ") final_hidden_states = ( self.experts( hidden_states=hidden_states, @@ -360,12 +364,16 @@ def forward_deepep( * self.routed_scaling_factor ) if self.ep_size > 1: + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep before combine " + f"{self.layer_id=} {final_hidden_states[:, :5]=} ") final_hidden_states = self.deepep_dispatcher.combine( final_hidden_states, topk_idx, topk_weights, forward_mode, ) + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after combine " + f"{self.layer_id=} {final_hidden_states[:, :5]=} ") if shared_output is not None: final_hidden_states = final_hidden_states + shared_output @@ -1210,6 +1218,8 @@ def forward_deepep( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep start {self.layer_id=} {self.mlp.__class__.__name__=} " + f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}") if hidden_states.shape[0] == 0: residual = hidden_states @@ -1272,6 +1282,8 @@ def forward_deepep( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep end {self.layer_id=} {self.mlp.__class__.__name__=} " + f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}") return hidden_states, residual @@ -1391,6 +1403,7 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward start {input_ids=}") hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) From 81db4402f0bec9597641fd4742e8e358f186354a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:09:27 +0800 Subject: [PATCH 0598/1089] more --- python/sglang/srt/models/deepseek_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37ddd2a51eb..3f2ba43fa41 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -73,7 +73,7 @@ ExpertLocationMetadata, ModelConfigForExpertLocation, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict +from 
sglang.srt.managers.schedule_batch import global_server_args_dict, get_global_expert_location_metadata
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
@@ -183,6 +183,7 @@ def __init__(
             else 0
         )
         self.layer_id = layer_id
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.routed_scaling_factor = config.routed_scaling_factor

         if self.tp_size > config.n_routed_experts:
@@ -331,6 +332,7 @@ def forward_deepep(
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[self.tp_rank, self.layer_id, :],
         )
         print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts "
               f"{self.layer_id=} {topk_weights=} {topk_idx=} ")
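Patch 0598 above threads a per-rank dispatch map into `select_experts`: `logical_to_rank_dispatch_physical_map[tp_rank, layer_id, :]` tells each rank which physical slot to use for every logical expert, so redundant copies can absorb traffic. The exact semantics live in expert_location.py; the following is only a rough sketch of the intended lookup, with made-up shapes (2 ranks, 1 layer, 4 logical experts, logical expert 1 duplicated at physical slot 5):

    import torch

    # map[rank, layer, logical_expert] -> physical slot that rank dispatches to
    logical_to_rank_dispatch_physical_map = torch.tensor(
        [
            [[0, 1, 2, 3]],  # rank 0 keeps the primary copies
            [[0, 5, 2, 3]],  # rank 1 is routed to the redundant copy of expert 1
        ]
    )

    tp_rank, layer_id = 1, 0
    topk_idx_logical = torch.tensor([[1, 3], [0, 1]])  # output of top-k routing
    topk_idx_physical = logical_to_rank_dispatch_physical_map[tp_rank, layer_id][
        topk_idx_logical
    ]
    print(topk_idx_physical.tolist())  # -> [[5, 3], [0, 5]]

Different ranks disagreeing on the physical slot for the same logical expert is the point: it spreads load across the duplicates.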
From 77a986b8f1e43eaf9672e07d1f1723f3fdc5241f Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:12:17 +0800
Subject: [PATCH 0599/1089] more

---
 test/srt/test_eplb.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py
index 4add5b91790..391387666d4 100755
--- a/test/srt/test_eplb.py
+++ b/test/srt/test_eplb.py
@@ -145,7 +145,7 @@ def test_nontrivial_location(self):
         physical_to_logical_map = (
             (offset + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1))
             % _NUM_ROUTED_EXPERTS
-        )
+        ).tolist()
         init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist())

From 7c968d4cae48b8c1dcb1cd2210beb9ebbc36f194 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:12:38 +0800
Subject: [PATCH 0600/1089] temp

---
 python/sglang/srt/models/deepseek_v2.py | 26 ++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 3f2ba43fa41..33f51e528e3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -334,8 +334,8 @@ def forward_deepep(
             correction_bias=self.correction_bias,
             expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[self.tp_rank, self.layer_id, :],
         )
-        print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts "
-              f"{self.layer_id=} {topk_weights=} {topk_idx=} ")
+        # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts "
+        #       f"{self.layer_id=} {topk_weights=} {topk_idx=} ")
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
                 topk_idx,
                 topk_weights,
                 reorder_topk_ids,
                 seg_indptr,
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
                 hidden_states,
                 topk_idx,
                 topk_weights,
                 forward_mode=forward_mode,
             )
-        print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after dispatch "
-              f"{self.layer_id=} {topk_weights=} {topk_idx=} {hidden_states[:, :5]=} ")
+        # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after dispatch "
+        #       f"{self.layer_id=} {topk_weights=} {topk_idx=} {hidden_states[:, :5]=} ")
         final_hidden_states = (
             self.experts(
                 hidden_states=hidden_states,
@@ -366,16 +366,16 @@ def forward_deepep(
             * self.routed_scaling_factor
         )
         if self.ep_size > 1:
-            print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep before combine "
-                  f"{self.layer_id=} {final_hidden_states[:, :5]=} ")
+            # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep before combine "
+            #       f"{self.layer_id=} {final_hidden_states[:, :5]=} ")
             final_hidden_states = self.deepep_dispatcher.combine(
                 final_hidden_states,
                 topk_idx,
                 topk_weights,
                 forward_mode,
             )
-            print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after combine "
-                  f"{self.layer_id=} {final_hidden_states[:, :5]=} ")
+            # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after combine "
+            #       f"{self.layer_id=} {final_hidden_states[:, :5]=} ")

         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -1220,8 +1220,8 @@ def forward_deepep(
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep start {self.layer_id=} {self.mlp.__class__.__name__=} "
-              f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}")
+        # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep start {self.layer_id=} {self.mlp.__class__.__name__=} "
+        #       f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}")

         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -1284,8 +1284,8 @@ def forward_deepep(
                 list(hidden_states.tensor_split(self.attn_tp_size)),
                 local_hidden_states
             )
-        print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep end {self.layer_id=} {self.mlp.__class__.__name__=} "
-              f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}")
+        # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep end {self.layer_id=} {self.mlp.__class__.__name__=} "
+        #       f"{hidden_states.shape=} {hidden_states[:1, :5]=} {residual[:1, :5] if residual is not None else None=}")

         return hidden_states, residual

@@ -1405,7 +1405,7 @@ def forward(
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward start {input_ids=}")
+        # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward start {input_ids=}")

         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

From 7b72597deaa365f0405a9dca3ec464275c26090e Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:15:18 +0800
Subject: [PATCH 0601/1089] more

---
 test/srt/test_eplb.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py
index 391387666d4..3294a1b825c 100755
--- a/test/srt/test_eplb.py
+++ b/test/srt/test_eplb.py
@@ -145,11 +145,11 @@ def test_nontrivial_location(self):
         physical_to_logical_map = (
             (offset + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1))
             % _NUM_ROUTED_EXPERTS
-        ).tolist()
+        )
         init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist())

         engine = sgl.Engine(**engine_kwargs, init_expert_location=json.dumps(init_expert_location))
-        self._assert_behavior(engine, physical_to_logical_map[0])
+
self._assert_behavior(engine, physical_to_logical_map[0].tolist()) engine.shutdown() del engine From d7dc8cf0666d758641c3b7e12673431a355136c5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:17:09 +0800 Subject: [PATCH 0602/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 3294a1b825c..45c16a2e7c2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -26,7 +26,7 @@ class TestEPLB(CustomTestCase): - def _tempdisable_test_eplb_e2e(self): + def test_eplb_e2e(self): print("Action: test_eplb_e2e") with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( @@ -72,7 +72,7 @@ def _tempdisable_test_eplb_e2e(self): engine.shutdown() del engine - def _tempdisable_test_eplb_init_expert_location_and_save_expert_distribution(self): + def test_eplb_init_expert_location_and_save_expert_distribution(self): print("Action: test_eplb_init_expert_location_and_save_expert_distribution") with tempfile.TemporaryDirectory() as eplb_storage_dir_a, tempfile.TemporaryDirectory() as eplb_storage_dir_b: engine_kwargs = dict( From 61413cbc84b6d439fd593239848da610cf8d9a6d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:17:40 +0800 Subject: [PATCH 0603/1089] temp --- test/srt/test_eplb.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 45c16a2e7c2..df1cf941d25 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -18,11 +18,7 @@ # DeepSeek-Coder-V2-Lite-Instruct _NUM_ROUTED_EXPERTS = 64 _NUM_HIDDEN_LAYERS = 27 -# TODO -# TODO temp -# TODO -# _REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] -_REF_OUTPUT = [', 4+4=8,'] +_REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] class TestEPLB(CustomTestCase): @@ -200,11 +196,7 @@ def _assert_behavior( def _engine_generate(self, engine: sgl.Engine): output = engine.generate( - # TODO - # TODO temp - # TODO - prompt=["1+1=2, 2+2=4"], - # prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], sampling_params=dict(max_new_tokens=8, temperature=0.0), ) print(f"engine_generate {output=}") From 3b16ccdccfaf492b257fc15e01d7245795626a37 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:23:09 +0800 Subject: [PATCH 0604/1089] more --- test/srt/test_eplb.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index df1cf941d25..9378000f922 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -58,7 +58,7 @@ def test_eplb_e2e(self): del engine print(f"Action: start engine") - engine = sgl.Engine(**engine_kwargs) + engine = sgl.Engine(**engine_kwargs, port=21000) self._assert_behavior( engine, physical_to_logical_map_layer_0_after_first_rebalance, @@ -106,6 +106,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): **engine_kwargs, eplb_storage_dir=eplb_storage_dir_b, init_expert_location=str(snapshot_path), + port=21000, ) self._assert_behavior(engine, "not_equal_trivial") print(f"Action: shutdown engine") @@ -115,7 +116,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): print( f"Action: start engine to check automatically loading from storage dir" ) - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a) + engine = sgl.Engine(**engine_kwargs, 
eplb_storage_dir=eplb_storage_dir_a, port=22000)
         self._assert_behavior(engine, "not_equal_trivial")
         print(f"Action: shutdown engine")
         engine.shutdown()
         del engine

From 54ccacae80df762835d59323c4b2068c3ced65fa Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:27:54 +0800
Subject: [PATCH 0605/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index fb31d415d76..9ffa870559f 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -83,7 +83,7 @@ def _reset(self):
         """Reset the expert distribution recorder."""
         logger.info("Resetting ExpertDistributionRecorder...")
         self._recording = False
-        assert self._current_layer_idx.value is None
+        assert self._current_layer_idx.value is None, f"{self._current_layer_idx.value=}"
         for gatherer in self._single_pass_gatherers.values():
             gatherer.reset()
         self._accumulator.reset()

From 229d8cc811f5d3f501216ad266f340a080d173d0 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:28:29 +0800
Subject: [PATCH 0606/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 9ffa870559f..20af0fbd4c1 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -83,7 +83,6 @@ def _reset(self):
         """Reset the expert distribution recorder."""
         logger.info("Resetting ExpertDistributionRecorder...")
         self._recording = False
-        assert self._current_layer_idx.value is None, f"{self._current_layer_idx.value=}"
         for gatherer in self._single_pass_gatherers.values():
             gatherer.reset()
         self._accumulator.reset()

From a63c7482754fd1f10c5fb7c1be3c4a0dc6e0d506 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:28:53 +0800
Subject: [PATCH 0607/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 20af0fbd4c1..327f7bb95a9 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -82,7 +82,6 @@ def _on_hook(self, hook_name: str, **kwargs):

     def _reset(self):
         """Reset the expert distribution recorder."""
         logger.info("Resetting ExpertDistributionRecorder...")
-        self._recording = False
         for gatherer in self._single_pass_gatherers.values():
             gatherer.reset()
         self._accumulator.reset()

From c3a5d6004c3a8d14059fb32889089ddc5c9d3a26 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:30:19 +0800
Subject: [PATCH 0608/1089] Revert "more"

This reverts commit 229d8cc811f5d3f501216ad266f340a080d173d0.
# Conflicts: # python/sglang/srt/managers/expert_distribution.py --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 327f7bb95a9..27634d8b78a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -82,6 +82,7 @@ def _on_hook(self, hook_name: str, **kwargs): def _reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") + assert self._current_layer_idx.value is None, f"{self._current_layer_idx.value=}" for gatherer in self._single_pass_gatherers.values(): gatherer.reset() self._accumulator.reset() From 0db27725392c908eda504882972a719fb7b6a33a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:33:45 +0800 Subject: [PATCH 0609/1089] temp --- python/sglang/srt/managers/expert_distribution.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 27634d8b78a..868a061987e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,4 +1,5 @@ import logging +import threading from abc import ABC from contextlib import contextmanager from copy import deepcopy @@ -6,6 +7,7 @@ import torch +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -37,8 +39,12 @@ def initialize( for k in self._accumulator.get_single_pass_gatherer_keys() } + @contextmanager def with_current_layer(self, layer_idx): - return self._current_layer_idx.with_value(layer_idx) + print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] with_current_layer start {layer_idx=}") + with self._current_layer_idx.with_value(layer_idx): + yield + print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] with_current_layer end {layer_idx=}") def with_debug_name(self, debug_name): return self._current_debug_name.with_value(debug_name) @@ -106,6 +112,7 @@ def stop_record(self): def dump_record(self): """Dump the expert distribution record and reset the recorder after dumping.""" + print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] dump_record start") output = self._accumulator.dump() self._reset() return output From f51d7228710b0ac888fbf800956c650ae2a48b93 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 11 Apr 2025 17:55:29 +0800 Subject: [PATCH 0610/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 868a061987e..cf89e269cbd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -32,6 +32,7 @@ def initialize( expert_location_metadata: "ExpertLocationMetadata", rank: int, ): + assert server_args.disable_overlap_schedule, "ExpertDistributionRecorder needs disable_overlap_schedule currently (will implement this later)" self._expert_location_metadata = expert_location_metadata self._accumulator = 
_Accumulator.init_new(expert_location_metadata, rank)
         self._single_pass_gatherers = {

From 46001c934e8a5aba0046b3d18243a7f472c77eda Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:55:42 +0800
Subject: [PATCH 0611/1089] more

---
 test/srt/test_eplb.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py
index 9378000f922..34318a796a0 100755
--- a/test/srt/test_eplb.py
+++ b/test/srt/test_eplb.py
@@ -80,6 +80,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self):
                 enable_deepep_moe=True,
                 deepep_mode="normal",
                 disable_cuda_graph=True,
+                disable_overlap_schedule=True,  # TODO
                 tp_size=2,
                 dp_size=2,
                 log_level="info",

From 3690b1e1a1439e5154549a49610cf43a39580573 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:56:18 +0800
Subject: [PATCH 0612/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index cf89e269cbd..9957761d4eb 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -32,7 +32,7 @@ def initialize(
         expert_location_metadata: "ExpertLocationMetadata",
         rank: int,
     ):
-        assert server_args.disable_overlap_schedule, "ExpertDistributionRecorder needs disable_overlap_schedule currently (will implement this later)"
+        self._server_args = server_args
         self._expert_location_metadata = expert_location_metadata
         self._accumulator = _Accumulator.init_new(expert_location_metadata, rank)
         self._single_pass_gatherers = {
@@ -100,6 +100,7 @@ def start_record(self):
             logger.warning(
                 "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
             )
+        assert self._server_args.disable_overlap_schedule, "ExpertDistributionRecorder needs disable_overlap_schedule currently (will implement this later)"
         self._reset()
         self._recording = True
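Patches 0605-0612 above circle one constraint: the recorder's per-forward state (such as `_current_layer_idx`) is only guaranteed to be empty between forward passes, which holds only when the overlap scheduler is off; patch 0612 therefore moves the `disable_overlap_schedule` assert from `initialize` to `start_record`, so the requirement is imposed only when recording is actually requested. A minimal stand-in for the `Withable` scope tracking involved (the real class lives in sglang.srt.utils; this sketch only assumes it behaves like a with-scoped value holder):

    from contextlib import contextmanager

    class Withable:
        # Holds a value only while a `with` scope is active, mirroring how
        # the recorder tracks which layer is currently executing.
        def __init__(self):
            self._value = None

        @property
        def value(self):
            return self._value

        @contextmanager
        def with_value(self, value):
            assert self._value is None  # no nesting, no concurrent forward
            self._value = value
            try:
                yield
            finally:
                self._value = None

    current_layer_idx = Withable()
    with current_layer_idx.with_value(3):
        assert current_layer_idx.value == 3
    assert current_layer_idx.value is None  # what _reset's assert relied on

With overlap scheduling, a dump or reset can land while a forward pass is mid-layer, which is exactly when the `_current_layer_idx.value is None` assert fails.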
From 5b0a54d7690f6309e648f9772618a189a4678af1 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 11 Apr 2025 17:57:51 +0800
Subject: [PATCH 0613/1089] Revert "temp"

This reverts commit 0db27725392c908eda504882972a719fb7b6a33a.
---
 python/sglang/srt/managers/expert_distribution.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 9957761d4eb..a53b0d29cc9 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -1,5 +1,4 @@
 import logging
-import threading
 from abc import ABC
 from contextlib import contextmanager
 from copy import deepcopy
@@ -7,7 +6,6 @@
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_bool_env_var
@@ -40,12 +38,8 @@ def initialize(
             for k in self._accumulator.get_single_pass_gatherer_keys()
         }

-    @contextmanager
     def with_current_layer(self, layer_idx):
-        print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] with_current_layer start {layer_idx=}")
-        with self._current_layer_idx.with_value(layer_idx):
-            yield
-        print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] with_current_layer end {layer_idx=}")
+        return self._current_layer_idx.with_value(layer_idx)

     def with_debug_name(self, debug_name):
         return self._current_debug_name.with_value(debug_name)
@@ -114,7 +108,6 @@ def stop_record(self):

     def dump_record(self):
         """Dump the expert distribution record and reset the recorder after dumping."""
-        print(f"hi [{get_tensor_model_parallel_rank()}, {threading.get_native_id()}, {self.__class__.__name__}] dump_record start")
         output = self._accumulator.dump()
         self._reset()
         return output

From 9ab1f2a066138b8ba5f60341d640b822875ad612 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Sat, 5 Apr 2025 22:52:09 +0800
Subject: [PATCH 0614/1089] rebase to master

---
 .../sglang/srt/layers/moe/ep_moe/kernels.py   |  12 +-
 python/sglang/srt/layers/moe/ep_moe/layer.py  |  48 +-
 .../srt/layers/moe/ep_moe/token_dispatcher.py |   5 +-
 python/sglang/srt/models/deepseek_v2.py       | 811 ++++++++++++------
 python/sglang/srt/utils.py                    |  48 ++
 5 files changed, 632 insertions(+), 292 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 3ea6b4b2f68..87b9bea52bf 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -7,7 +7,7 @@
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import DisposibleBox, is_cuda

 _is_cuda = is_cuda()
 if _is_cuda:
@@ -655,7 +655,12 @@ def grouped_gemm_triton(
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
+            if isinstance(a, DisposibleBox):
+                a_box = a
+                a, scale_a = sglang_per_token_group_quant_fp8(a.value, block_k)
+                a_box.dispose()
+            else:
+                a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
         else:
             a, scale_a = per_token_group_quant_fp8(a, block_k)

@@ -676,6 +681,9 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )

+    if not isinstance(c, torch.Tensor):
+        c = c()
+
    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
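The kernels.py hunk above is the consumer side of the `DisposibleBox` helper this patch imports from sglang.srt.utils: the caller hands the grouped GEMM its input wrapped in a box, and the GEMM frees the original activations the moment the quantized fp8 copy exists, instead of keeping both alive for the whole call; the `c = c()` branch similarly defers allocating the output until the input has been freed, lowering peak memory. A minimal sketch of the pattern, with `.half()` standing in for the real fp8 per-token-group quantization:

    import torch

    class DisposibleBox:
        # Single-owner container: lets a callee drop the last reference to a
        # large tensor as soon as a derived tensor has been produced.
        def __init__(self, value: torch.Tensor):
            self.value = value

        def dispose(self):
            self.value = None

    def quantize_then_free(a):
        if isinstance(a, DisposibleBox):
            a_box = a
            a_q = a_box.value.half()  # stand-in for the quantization step
            a_box.dispose()           # original activations can be freed here
            return a_q
        return a.half()

    box = DisposibleBox(torch.randn(4, 8))
    out = quantize_then_free(box)
    assert box.value is None and out.dtype == torch.float16

This only helps if the box holds the sole reference to the tensor, which is why the layer.py diff below reads `hidden_states.value.shape`, `.dtype`, and `.device` into locals up front rather than keeping the raw tensor around.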
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index dfecb63d940..44df6b665eb 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -38,7 +38,13 @@
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    DeepEPMode,
+    DisposibleBox,
+    is_cuda,
+    is_hip,
+    set_weight_attrs,
+)

 _is_cuda = is_cuda()

@@ -867,7 +873,7 @@ def forward(

     def forward_normal(
         self,
-        hidden_states: torch.Tensor,
+        hidden_states: DisposibleBox,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
@@ -875,12 +881,16 @@
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+                hidden_states.value.device, use_flashinfer=False  # TODO: use flashinfer
             )

+        hidden_states_device = hidden_states.value.device
+        hidden_states_shape = hidden_states.value.shape
+        hidden_states_dtype = hidden_states.value.dtype
+
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
-                torch.max(hidden_states)
+                torch.max(hidden_states.value)
                 .repeat(self.num_experts_per_partition)
                 .to(torch.float32)
             )
@@ -888,23 +898,24 @@
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )

         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
+        gateup_output_creator = lambda: torch.empty(
+            hidden_states_shape[0],
             self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )

-        if hidden_states.shape[0] > 0:
+        if hidden_states.value.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
+                # NOTE pass in box
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=gateup_output_creator,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -918,6 +929,11 @@
             ),
             block_shape=self.block_shape,
         )
+        else:
+            gateup_output = gateup_output_creator()
+
+        # NOTE disposed earlier
+        # hidden_states.dispose()

         # Act
         down_input = torch.empty(
@@ -927,14 +943,14 @@
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )

         if self.activation == "silu":
@@ -951,12 +967,14 @@
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")

+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(

diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 1f27b339337..3880dca2c3d 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++
b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,4 @@ -from sglang.srt.utils import DeepEPMode +from sglang.srt.utils import DeepEPMode, DisposibleBox try: from deep_ep import Buffer @@ -205,6 +205,7 @@ def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event): seg_indptr = torch.zeros( (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64 ) + hidden_states = DisposibleBox(hidden_states) masked_m = expected_m = None @@ -308,7 +309,7 @@ def _deepep_permute( hidden_states.shape[1], BLOCK_SIZE=512, ) - return reorder_topk_ids, seg_indptr, gateup_input + return reorder_topk_ids, seg_indptr, DisposibleBox(gateup_input) def combine_a( self, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6a960a37150..ccc9ed0619b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -16,22 +16,28 @@ # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py """Inference-only DeepseekV2 model.""" -import logging import os +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial from typing import Any, Dict, Iterable, Optional, Tuple import torch import torch.nn.functional as F from torch import nn -from tqdm import tqdm from transformers import PretrainedConfig +from sglang.srt import two_batch_overlap from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, parallel_state, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( + decode_attention_fwd_grouped_rope, +) from sglang.srt.layers.dp_attention import ( dp_gather_partial, dp_scatter, @@ -50,12 +56,12 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE +from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, - channel_quant_to_tensor_quant, input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) @@ -68,19 +74,25 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip +from sglang.srt.utils import ( + DeepEPMode, + add_prefix, + configure_deep_gemm_num_sms, + is_cuda, + is_cuda_available, + is_hip, +) _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8 - - from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher else: from vllm import _custom_ops as ops @@ -89,10 +101,6 @@ decode_attention_fwd_grouped_rope, ) -expert_distribution_recorder = ExpertDistributionRecorder() - -logger = logging.getLogger(__name__) - 
class DeepseekV2MLP(nn.Module): def __init__( @@ -133,7 +141,7 @@ def __init__( ) self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x, forward_mode: Optional[ForwardMode] = None): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -168,19 +176,15 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + layer_id: int = -999, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts - self.n_share_experts_fusion = ( - global_server_args_dict["n_share_experts_fusion"] - if global_server_args_dict["n_share_experts_fusion"] is not None - else 0 - ) - self.routed_scaling_factor = config.routed_scaling_factor + self.layer_id = layer_id if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -200,10 +204,9 @@ def __init__( if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) - self.experts = MoEImpl( - num_experts=config.n_routed_experts + self.n_share_experts_fusion, - top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, renormalize=config.norm_topk_prob, @@ -220,7 +223,7 @@ def __init__( ), ) - if config.n_shared_experts is not None and self.n_share_experts_fusion == 0: + if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts # disable tp for shared experts when enable deepep moe if not global_server_args_dict["enable_deepep_moe"]: @@ -258,18 +261,27 @@ def __init__( else None ) - self.deepep_dispatcher = DeepEPDispatcher( - group=parallel_state.get_tp_group().device_group, - router_topk=self.top_k, - permute_fusion=True, - num_experts=config.n_routed_experts, - num_local_experts=config.n_routed_experts // self.tp_size, - hidden_size=config.hidden_size, - params_dtype=config.torch_dtype, - deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], - async_finish=True, # TODO - return_recv_hook=True, - ) + self.deepep_dispatcher = self._create_deepep_dispatcher(config) + + if global_server_args_dict["enable_two_batch_overlap"]: + # TODO maybe we do not need to create 2+1 dispatchers, but can reuse the one above + self.tbo_deepep_dispatchers = [ + self._create_deepep_dispatcher(config) for i in range(2) + ] + + def _create_deepep_dispatcher(self, config): + return DeepEPDispatcher( + group=parallel_state.get_tp_group().device_group, + router_topk=self.top_k, + permute_fusion=True, + num_experts=config.n_routed_experts, + num_local_experts=config.n_routed_experts // self.tp_size, + hidden_size=config.hidden_size, + params_dtype=config.torch_dtype, + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], + async_finish=True, # TODO + return_recv_hook=True, + ) def forward( self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None @@ -280,13 +292,12 @@ def forward( return self.forward_deepep(hidden_states, forward_mode) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - shared_output = self._forward_shared_experts(hidden_states) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) # router_logits: 
(num_tokens, n_experts) router_logits = self.gate(hidden_states) - final_hidden_states = ( - self.experts(hidden_states=hidden_states, router_logits=router_logits) - * self.routed_scaling_factor - ) + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + final_hidden_states *= self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: @@ -296,7 +307,66 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward_deepep( self, hidden_states: torch.Tensor, forward_mode: ForwardMode ) -> torch.Tensor: - shared_output = None + shared_output = self._forward_deepep_shared_output(forward_mode, hidden_states) + + if ( + forward_mode is not None + and not forward_mode.is_idle() + and hidden_states.shape[0] > 0 + ): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + else: + router_logits = None + + self._forward_deepep_dispatch_a( + self.deepep_dispatcher, forward_mode, hidden_states, router_logits + ) + ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) = self.deepep_dispatcher.dispatch_b() + + final_hidden_states = self._forward_deepep_expert( + hidden_states=hidden_states, + reorder_topk_ids=reorder_topk_ids, + seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, + forward_mode=forward_mode, + ) + + if self.tp_size > 1: + final_hidden_states = self.deepep_dispatcher.combine( + hidden_states=final_hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + forward_mode=forward_mode, + ) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + return final_hidden_states + + def _forward_deepep_shared_output(self, forward_mode, hidden_states): + if ( + forward_mode is not None + and not forward_mode.is_idle() + and hidden_states.shape[0] > 0 + and self.n_shared_experts is not None + ): + return self.shared_experts(hidden_states) + return None + + def _forward_deepep_dispatch_a( + self, chosen_deepep_dispatcher, forward_mode, hidden_states, router_logits + ): topk_idx = torch.full( (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device ) @@ -308,9 +378,6 @@ def forward_deepep( and not forward_mode.is_idle() and hidden_states.shape[0] > 0 ): - # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) - shared_output = self._forward_shared_experts(hidden_states) topk_weights, topk_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -321,50 +388,101 @@ def forward_deepep( num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, ) - if self.ep_size > 1: - # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value + chosen_deepep_dispatcher.dispatch_a( + hidden_states, + topk_idx, + topk_weights, + self.num_experts, + forward_mode=forward_mode, + ) + + def _forward_deepep_expert( + self, + hidden_states, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + forward_mode, + ): + output = self.experts( + hidden_states=hidden_states, + reorder_topk_ids=reorder_topk_ids, + seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, + forward_mode=forward_mode, + ) + output *= self.routed_scaling_factor + return output + + # ----------------------------------------- TBO-related -------------------------------------------- + + def _forward_tbo_op_gate(self, state): + state.router_logits = 
self.gate(state.hidden_states_after_post_attn_ln) + + def _forward_tbo_op_mlp(self, state): + state.expert_output_hidden_states = self._forward_deepep_expert( + hidden_states=state.hidden_states_from_dispatch, + reorder_topk_ids=state.reorder_topk_ids_from_dispatch, + seg_indptr=state.seg_indptr_from_dispatch, + masked_m=state.masked_m_from_dispatch, + expected_m=state.expected_m_from_dispatch, + forward_mode=state.forward_batch.forward_mode, + ) + + def _forward_tbo_op_dispatch_a(self, state): + self._forward_deepep_dispatch_a( + self.tbo_deepep_dispatchers[state.tbo_subbatch_index], + state.forward_batch.forward_mode, + state.hidden_states_after_post_attn_ln, + state.router_logits, + ) + + def _forward_tbo_op_dispatch_b(self, state, tbo_child_index: int): + dispatcher = self.tbo_deepep_dispatchers[state.tbo_subbatch_index] + with expert_distribution_recorder.with_current_layer( + self.layer_id + ), expert_distribution_recorder.with_debug_name( + ["child_a", "child_b"][tbo_child_index] + ): ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - seg_indptr, - masked_m, - expected_m, - ) = self.deepep_dispatcher.dispatch( - hidden_states, - topk_idx, - topk_weights, - forward_mode=forward_mode, - ) - final_hidden_states = ( - self.experts( - hidden_states=hidden_states, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - forward_mode=forward_mode, - ) - * self.routed_scaling_factor + state.hidden_states_from_dispatch, + state.topk_idx_from_dispatch, + state.topk_weights_from_dispatch, + state.reorder_topk_ids_from_dispatch, + state.seg_indptr_from_dispatch, + state.masked_m_from_dispatch, + state.expected_m_from_dispatch, + ) = dispatcher.dispatch_b() + + def _forward_tbo_op_combine_a(self, state): + self.tbo_deepep_dispatchers[state.tbo_subbatch_index].combine_a( + hidden_states=state.expert_output_hidden_states, + topk_idx=state.topk_idx_from_dispatch, + topk_weights=state.topk_weights_from_dispatch, + forward_mode=state.forward_batch.forward_mode, ) - if self.ep_size > 1: - final_hidden_states = self.deepep_dispatcher.combine( - final_hidden_states, - topk_idx, - topk_weights, - forward_mode, - ) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - return final_hidden_states + def _forward_tbo_op_combine_b(self, state): + dispatcher = self.tbo_deepep_dispatchers[state.tbo_subbatch_index] + state.hidden_states_from_combine = dispatcher.combine_b() - def _forward_shared_experts(self, hidden_states): - if self.n_shared_experts is not None and self.n_share_experts_fusion == 0: - return self.shared_experts(hidden_states) - else: - return None + def _forward_tbo_op_shared(self, state): + state.shared_output = self._forward_deepep_shared_output( + state.forward_batch.forward_mode, state.hidden_states_after_post_attn_ln + ) + + def _forward_tbo_op_compute_layer_output(self, state): + output = dict( + positions=state.positions, + hidden_states=state.hidden_states_from_combine + state.shared_output, + forward_batch=state.forward_batch, + residual=state.residual_after_post_attn_ln, + tbo_subbatch_index=state.tbo_subbatch_index, + ) + state.clear() + return output def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: @@ -493,7 +611,6 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, - quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -674,7 +791,6 @@ def __init__( num_kv_heads=1, layer_id=layer_id, 
v_head_dim=self.kv_lora_rank, - quant_config=quant_config, prefix=add_prefix("attn_mqa", prefix), ) @@ -685,7 +801,6 @@ def __init__( num_kv_heads=self.num_local_heads, layer_id=layer_id, v_head_dim=self.v_head_dim, - quant_config=quant_config, prefix=add_prefix("attn_mha", prefix), ) @@ -693,6 +808,7 @@ def __init__( self.w_vc = None self.w_scale = None + self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"] self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] @@ -700,7 +816,7 @@ def __init__( self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" def no_absorb(self, forward_batch: ForwardBatch) -> bool: - if self.attention_backend == "flashinfer": + if self.enable_flashinfer_mla: # Flashinfer MLA: Do not absorb when enabling ragged prefill return ( not self.flashinfer_mla_disable_ragged @@ -797,6 +913,17 @@ def forward_absorb( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: + state = self.forward_absorb_stage_prepare( + positions, hidden_states, forward_batch + ) + return self.forward_absorb_stage_core(state) + + def forward_absorb_stage_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim @@ -839,6 +966,11 @@ def forward_absorb( q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe + return q_input, k_input, v_input, forward_batch + + def forward_absorb_stage_core(self, state) -> torch.Tensor: + q_input, k_input, v_input, forward_batch = state + attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -1010,6 +1142,17 @@ def forward_absorb_fused_mla_rope( return output +class _DecoderLayerExecutionMode(Enum): + MLP_ONE = auto() + MLP_ALL = auto() + + +@dataclass +class _DecoderLayerInfo: + is_sparse: bool + execution_mode: _DecoderLayerExecutionMode + + class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -1020,14 +1163,6 @@ def __init__( is_nextn: bool = False, prefix: str = "", ) -> None: - - def is_sparse_layer(l: int): - return ( - config.n_routed_experts is not None - and l >= config.first_k_dense_replace - and l % config.moe_layer_freq == 0 - ) - super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -1080,26 +1215,35 @@ def is_sparse_layer(l: int): prefix=add_prefix("self_attn", prefix), ) - if is_nextn or is_sparse_layer(layer_id): + self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn) + previous_layer_info = self._compute_info( + config, layer_id=layer_id - 1, is_nextn=False + ) + + if self.info.is_sparse: self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, ) - self.is_sparse = True else: + if self._enable_moe_dense_fully_dp(): + mlp_tp_rank, mlp_tp_size = 0, 1 + else: + mlp_tp_rank, mlp_tp_size = None, None self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), + tp_rank=mlp_tp_rank, + tp_size=mlp_tp_size, ) - self.is_sparse = False self.input_is_scattered = ( - is_sparse_layer(layer_id - 1) - and global_server_args_dict["enable_deepep_moe"] + 
previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1108,6 +1252,25 @@ def is_sparse_layer(l: int): config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + + @staticmethod + def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): + is_sparse = is_nextn or ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ) + execution_mode = ( + _DecoderLayerExecutionMode.MLP_ONE + if (global_server_args_dict["enable_deepep_moe"] and is_sparse) + or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + else _DecoderLayerExecutionMode.MLP_ALL + ) + return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) + def forward( self, positions: torch.Tensor, @@ -1115,16 +1278,18 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: - return self.forward_deepep( + if self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: + return self.forward_mode_mlp_one( positions, hidden_states, forward_batch, residual ) - else: - return self.forward_normal( + elif self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: + return self.forward_mode_mlp_all( positions, hidden_states, forward_batch, residual ) + else: + raise NotImplementedError - def forward_normal( + def forward_mode_mlp_all( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1135,11 +1300,9 @@ def forward_normal( if hidden_states.shape[0] == 0: residual = hidden_states else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self._forward_input_layernorm( + hidden_states, residual + ) assert not ( self.attn_tp_size != 1 and self.input_is_scattered @@ -1191,7 +1354,7 @@ def forward_normal( return hidden_states, residual - def forward_deepep( + def forward_mode_mlp_one( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1202,11 +1365,9 @@ def forward_deepep( if hidden_states.shape[0] == 0: residual = hidden_states else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self._forward_input_layernorm( + hidden_states, residual + ) if self.attn_tp_size != 1 and self.input_is_scattered: hidden_states, local_hidden_states = ( @@ -1247,7 +1408,13 @@ def forward_deepep( hidden_states, residual = self.post_attention_layernorm( hidden_states, residual ) - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + + if not ( + self._enable_moe_dense_fully_dp() + and (not self.info.is_sparse) + and hidden_states.shape[0] == 0 + ): + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: hidden_states += residual @@ -1262,6 +1429,117 @@ def forward_deepep( return hidden_states, residual + def _forward_input_layernorm(self, hidden_states, residual): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = 
self.input_layernorm(hidden_states, residual) + return hidden_states, residual + + # ----------------------------------------- TBO-related -------------------------------------------- + + def get_forward_tbo_operations( + self, forward_mode: ForwardMode, tbo_child_index: int + ): + if forward_mode == ForwardMode.EXTEND: + operations = [ + self._forward_tbo_op_input_layernorm, + self._forward_tbo_op_prefill_attn, + self._forward_tbo_op_post_attn_layernorm, + self.mlp._forward_tbo_op_gate, + self.mlp._forward_tbo_op_dispatch_a, + two_batch_overlap.YieldOperation(), + partial( + self.mlp._forward_tbo_op_dispatch_b, tbo_child_index=tbo_child_index + ), + self.mlp._forward_tbo_op_mlp, + self.mlp._forward_tbo_op_combine_a, + two_batch_overlap.YieldOperation(), + self.mlp._forward_tbo_op_shared, + self.mlp._forward_tbo_op_combine_b, + self.mlp._forward_tbo_op_compute_layer_output, + ] + elif forward_mode == ForwardMode.DECODE: + operations = [ + self._forward_tbo_op_input_layernorm, + self._forward_tbo_op_decode_attn_0, + two_batch_overlap.YieldOperation(), + self._forward_tbo_op_decode_attn_1, + self._forward_tbo_op_post_attn_layernorm, + self.mlp._forward_tbo_op_gate, + two_batch_overlap.YieldOperation(), + self.mlp._forward_tbo_op_dispatch_a, + self.mlp._forward_tbo_op_shared, + two_batch_overlap.YieldOperation(), + partial( + self.mlp._forward_tbo_op_dispatch_b, tbo_child_index=tbo_child_index + ), + self.mlp._forward_tbo_op_mlp, + self.mlp._forward_tbo_op_combine_a, + two_batch_overlap.YieldOperation(), + self.mlp._forward_tbo_op_combine_b, + self.mlp._forward_tbo_op_compute_layer_output, + two_batch_overlap.YieldOperation(), + ] + else: + raise NotImplementedError(f"Unsupported {forward_mode=}") + return two_batch_overlap.decorate_operations( + operations, debug_name_prefix=f"L{self.layer_id}-" + ) + + def _forward_tbo_op_input_layernorm( + self, + state, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + tbo_subbatch_index: int, + ): + state.hidden_states_after_input_ln, state.residual_after_input_ln = ( + self._forward_input_layernorm(hidden_states, residual) + ) + state.update( + dict( + forward_batch=forward_batch, + positions=positions, + tbo_subbatch_index=tbo_subbatch_index, + ) + ) + + def _forward_tbo_op_prefill_attn(self, state): + state.hidden_states_after_attn = self.self_attn( + positions=state.positions, + hidden_states=state.hidden_states_after_input_ln, + forward_batch=state.forward_batch, + ) + + def _forward_tbo_op_decode_attn_0(self, state): + state.self_attn_state = self.self_attn.forward_absorb_stage_prepare( + positions=state.positions, + hidden_states=state.hidden_states_after_input_ln, + forward_batch=state.forward_batch, + ) + + def _forward_tbo_op_decode_attn_1(self, state): + assert ( + (get_tensor_model_parallel_world_size() > 1) + and global_server_args_dict["enable_dp_attention"] + and global_server_args_dict["enable_deepep_moe"] + and isinstance(self.mlp, DeepseekV2MoE) + ) + state.hidden_states_after_attn = self.self_attn.forward_absorb_stage_core( + state.self_attn_state + ) + + def _forward_tbo_op_post_attn_layernorm(self, state): + state.hidden_states_after_post_attn_ln, state.residual_after_post_attn_ln = ( + self.post_attention_layernorm( + state.hidden_states_after_attn, state.residual_after_input_ln + ) + ) + class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False @@ -1275,6 +1553,7 @@ def __init__( super().__init__() self.padding_id = config.pad_token_id 
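# [editor's note] The EXTEND/DECODE operation lists built by
# get_forward_tbo_operations above split a layer's forward pass into stages
# separated by YieldOperation markers, so two micro-batches advance in
# lockstep and one batch's dispatch/combine communication can overlap the
# other's compute (the real runner also staggers the start via delta_stages).
# A minimal sketch of that interleaving, with illustrative names only -- this
# is not the sglang.srt.two_batch_overlap API:

class Yield:
    """Marker: hand control to the other micro-batch."""

def run_two_batches(ops_a, ops_b, state_a, state_b):
    pending = [(list(ops_a), state_a), (list(ops_b), state_b)]
    current = 0
    while pending[0][0] or pending[1][0]:
        ops, state = pending[current]
        while ops:
            op = ops.pop(0)
            if isinstance(op, Yield):
                break  # switch to the other micro-batch
            op(state)  # each op reads/writes its own sub-batch state
        current = 1 - current
    return state_a, state_b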
self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace self.embed_tokens = VocabParallelEmbedding( config.vocab_size, @@ -1310,12 +1589,27 @@ def forward( hidden_states = input_embeds residual = None - for i in range(len(self.layers)): - expert_distribution_recorder.set_current_layer(i) - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual - ) + + normal_num_layers = ( + self.first_k_dense_replace + if forward_batch.can_run_tbo + else len(self.layers) + ) + for i in range(normal_num_layers): + with expert_distribution_recorder.with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + hidden_states, residual = self._forward_tbo_layers( + positions=positions, + forward_batch=forward_batch, + hidden_states=hidden_states, + residual=residual, + start_layer=normal_num_layers, + ) + if not forward_batch.forward_mode.is_idle(): if residual is None: hidden_states = self.norm(hidden_states) @@ -1323,6 +1617,47 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def _forward_tbo_layers( + self, + positions: torch.Tensor, + forward_batch: ForwardBatch, + hidden_states: torch.Tensor, + residual: torch.Tensor, + start_layer: int, + ): + end_layer = len(self.layers) + if start_layer == end_layer: + return hidden_states, residual + + def compute_operations(tbo_child_index: str): + return [ + op + for i in range(start_layer, end_layer) + for op in self.layers[i].get_forward_tbo_operations( + forward_batch.forward_mode, tbo_child_index + ) + ] + + # TODO do not hardcode + chosen_num_sms = ( + torch.cuda.get_device_properties(device="cuda").multi_processor_count - 20 + ) + with configure_deep_gemm_num_sms(num_sms=chosen_num_sms): + return two_batch_overlap.model_forward_execute_two_batch( + inputs=dict( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + residual=residual, + ), + operations_a=compute_operations(0), + operations_b=compute_operations(1), + delta_stages={ + ForwardMode.EXTEND: 0, + ForwardMode.DECODE: 2, + }[forward_batch.forward_mode], + ) + class DeepseekV2ForCausalLM(nn.Module): @@ -1334,28 +1669,7 @@ def __init__( ) -> None: super().__init__() self.config = config - self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config - self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] - # Only Deepseek V3/R1 can use shared experts fusion optimization now. - if ( - global_server_args_dict.get("disable_shared_experts_fusion", False) - or self.config.architectures[0] != "DeepseekV3ForCausalLM" - or self.config.n_routed_experts != 256 - or self.config.routed_scaling_factor != 2.5 - ): - self.n_share_experts_fusion = None - global_server_args_dict["n_share_experts_fusion"] = None - logger.info( - "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled." - ) - elif self.n_share_experts_fusion is None: - global_server_args_dict["n_share_experts_fusion"] = self.tp_size - self.n_share_experts_fusion = self.tp_size - logger.info( - f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion." 
- ) - self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) @@ -1386,134 +1700,12 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch ) - def post_load_weights(self): - - # Perform post-processing after loading weights - - if not global_server_args_dict["disable_mla"]: - for layer_id in range(self.config.num_hidden_layers): - self_attn = self.model.layers[layer_id].self_attn - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = ops.awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T - else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - if w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - if hasattr(self.quant_config, "weight_block_size"): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - self_attn.w_scale = scale - - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 - ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0: - weights_list = list(weights) - weights_dict = dict(weights_list) - suffix_list = [ - "down_proj.weight", - "down_proj.weight_scale_inv", - "gate_proj.weight", - "gate_proj.weight_scale_inv", - "up_proj.weight", - "up_proj.weight_scale_inv", - ] - names_to_remove = [] - for moe_layer in tqdm( - range( - self.config.first_k_dense_replace, - self.config.num_hidden_layers, - self.config.moe_layer_freq, - ), - desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the 
shared expert into MoE", - ): - for num_repeat in range(self.n_share_experts_fusion): - for suffix in suffix_list: - shared_expert_weight_name = ( - f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}" - ) - weights_list.append( - ( - f"model.layers.{moe_layer}." - f"mlp.experts." - f"{self.config.n_routed_experts + num_repeat}" - f".{suffix}", - weights_dict[shared_expert_weight_name].clone(), - ) - ) - names_to_remove += [shared_expert_weight_name] - weights = [w for w in weights_list if w[0] not in names_to_remove] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -1526,12 +1718,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + num_experts=self.config.n_routed_experts, ) params_dict = dict(self.named_parameters()) @@ -1595,7 +1782,79 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - self.post_load_weights() + if not global_server_args_dict["disable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. 
+ if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight @@ -1608,6 +1867,12 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() + def get_expert_location_metadata(self): + return ExpertLocationMetadata.init_new( + num_layers=self.config.num_hidden_layers, + num_logical_experts=self.config.n_routed_experts, + ) + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d68fa489bee..86d1a709d9b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1821,6 +1821,40 @@ def resolve(self, forward_mode): return DeepEPMode.normal +T = TypeVar("T") + + +class Withable(Generic[T]): + def __init__(self): + self._value: Optional[T] = None + + @property + def value(self) -> T: + return self._value + + @contextmanager + def with_value(self, new_value: T): + assert self._value is None + self._value = new_value + try: + yield + finally: + assert self._value is new_value + self._value = None + + +@contextmanager +def configure_deep_gemm_num_sms(num_sms): + import deep_gemm + + original_num_sms = deep_gemm.get_num_sms() + deep_gemm.set_num_sms(num_sms) + try: + yield + finally: + deep_gemm.set_num_sms(original_num_sms) + + def fast_topk(values, topk, dim): if topk == 1: # Use max along the specified dimension to get both value and index @@ -1828,3 +1862,17 @@ def fast_topk(values, topk, dim): else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim) + + +class DisposibleBox: + def __init__(self, value): + self._value = value + + @property + def value(self): + assert self._value is not None + return self._value + + def dispose(self): + assert self._value is not None + self._value = None From e3782854a016579faeb3f7a151527a274d558ce4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 
08:02:33 +0800 Subject: [PATCH 0615/1089] cp back --- python/sglang/srt/models/deepseek_v2.py | 849 ++++++++---------------- 1 file changed, 292 insertions(+), 557 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ccc9ed0619b..e7f65b52144 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -16,28 +16,22 @@ # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py """Inference-only DeepseekV2 model.""" +import logging import os -from dataclasses import dataclass -from enum import Enum, auto -from functools import partial from typing import Any, Dict, Iterable, Optional, Tuple import torch import torch.nn.functional as F from torch import nn +from tqdm import tqdm from transformers import PretrainedConfig -from sglang.srt import two_batch_overlap from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, parallel_state, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( - decode_attention_fwd_grouped_rope, -) from sglang.srt.layers.dp_attention import ( dp_gather_partial, dp_scatter, @@ -56,12 +50,12 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE -from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) @@ -74,25 +68,19 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import ( - DeepEPMode, - add_prefix, - configure_deep_gemm_num_sms, - is_cuda, - is_cuda_available, - is_hip, -) +from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8 + + from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher else: from vllm import _custom_ops as ops @@ -101,6 +89,10 @@ decode_attention_fwd_grouped_rope, ) +expert_distribution_recorder = ExpertDistributionRecorder() + +logger = logging.getLogger(__name__) + class DeepseekV2MLP(nn.Module): def __init__( @@ -123,7 +115,7 @@ def __init__( prefix=add_prefix("gate_up_proj", prefix), tp_rank=tp_rank, tp_size=tp_size, - ) + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -141,7 +133,7 @@ def __init__( ) self.act_fn = SiluAndMul() - def forward(self, x, forward_mode: Optional[ForwardMode] = None): + def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -176,15 +168,19 @@ def __init__( self, config: 
PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - layer_id: int = -999, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts + self.n_share_experts_fusion = ( + global_server_args_dict["n_share_experts_fusion"] + if global_server_args_dict["n_share_experts_fusion"] is not None + else 0 + ) + self.routed_scaling_factor = config.routed_scaling_factor - self.layer_id = layer_id if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -204,9 +200,10 @@ def __init__( if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) + self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, + num_experts=config.n_routed_experts + self.n_share_experts_fusion, + top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, renormalize=config.norm_topk_prob, @@ -223,7 +220,7 @@ def __init__( ), ) - if config.n_shared_experts is not None: + if config.n_shared_experts is not None and self.n_share_experts_fusion == 0: intermediate_size = config.moe_intermediate_size * config.n_shared_experts # disable tp for shared experts when enable deepep moe if not global_server_args_dict["enable_deepep_moe"]: @@ -261,27 +258,18 @@ def __init__( else None ) - self.deepep_dispatcher = self._create_deepep_dispatcher(config) - - if global_server_args_dict["enable_two_batch_overlap"]: - # TODO maybe we do not need to create 2+1 dispatchers, but can reuse the one above - self.tbo_deepep_dispatchers = [ - self._create_deepep_dispatcher(config) for i in range(2) - ] - - def _create_deepep_dispatcher(self, config): - return DeepEPDispatcher( - group=parallel_state.get_tp_group().device_group, - router_topk=self.top_k, - permute_fusion=True, - num_experts=config.n_routed_experts, - num_local_experts=config.n_routed_experts // self.tp_size, - hidden_size=config.hidden_size, - params_dtype=config.torch_dtype, - deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], - async_finish=True, # TODO - return_recv_hook=True, - ) + self.deepep_dispatcher = DeepEPDispatcher( + group=parallel_state.get_tp_group().device_group, + router_topk=self.top_k, + permute_fusion=True, + num_experts=config.n_routed_experts, + num_local_experts=config.n_routed_experts // self.tp_size, + hidden_size=config.hidden_size, + params_dtype=config.torch_dtype, + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], + async_finish=True, # TODO + return_recv_hook=True, + ) def forward( self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None @@ -292,12 +280,13 @@ def forward( return self.forward_deepep(hidden_states, forward_mode) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - final_hidden_states *= self.routed_scaling_factor + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) 
if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: @@ -307,66 +296,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward_deepep( self, hidden_states: torch.Tensor, forward_mode: ForwardMode ) -> torch.Tensor: - shared_output = self._forward_deepep_shared_output(forward_mode, hidden_states) - - if ( - forward_mode is not None - and not forward_mode.is_idle() - and hidden_states.shape[0] > 0 - ): - # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) - else: - router_logits = None - - self._forward_deepep_dispatch_a( - self.deepep_dispatcher, forward_mode, hidden_states, router_logits - ) - ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - seg_indptr, - masked_m, - expected_m, - ) = self.deepep_dispatcher.dispatch_b() - - final_hidden_states = self._forward_deepep_expert( - hidden_states=hidden_states, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - forward_mode=forward_mode, - ) - - if self.tp_size > 1: - final_hidden_states = self.deepep_dispatcher.combine( - hidden_states=final_hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - forward_mode=forward_mode, - ) - - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - - return final_hidden_states - - def _forward_deepep_shared_output(self, forward_mode, hidden_states): - if ( - forward_mode is not None - and not forward_mode.is_idle() - and hidden_states.shape[0] > 0 - and self.n_shared_experts is not None - ): - return self.shared_experts(hidden_states) - return None - - def _forward_deepep_dispatch_a( - self, chosen_deepep_dispatcher, forward_mode, hidden_states, router_logits - ): + shared_output = None topk_idx = torch.full( (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device ) @@ -378,6 +308,9 @@ def _forward_deepep_dispatch_a( and not forward_mode.is_idle() and hidden_states.shape[0] > 0 ): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + shared_output = self._forward_shared_experts(hidden_states) topk_weights, topk_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -388,101 +321,50 @@ def _forward_deepep_dispatch_a( num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, ) - chosen_deepep_dispatcher.dispatch_a( - hidden_states, - topk_idx, - topk_weights, - self.num_experts, - forward_mode=forward_mode, - ) - - def _forward_deepep_expert( - self, - hidden_states, - reorder_topk_ids, - seg_indptr, - masked_m, - expected_m, - forward_mode, - ): - output = self.experts( - hidden_states=hidden_states, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - forward_mode=forward_mode, - ) - output *= self.routed_scaling_factor - return output - - # ----------------------------------------- TBO-related -------------------------------------------- - - def _forward_tbo_op_gate(self, state): - state.router_logits = self.gate(state.hidden_states_after_post_attn_ln) - - def _forward_tbo_op_mlp(self, state): - state.expert_output_hidden_states = self._forward_deepep_expert( - hidden_states=state.hidden_states_from_dispatch, - reorder_topk_ids=state.reorder_topk_ids_from_dispatch, - seg_indptr=state.seg_indptr_from_dispatch, - masked_m=state.masked_m_from_dispatch, - expected_m=state.expected_m_from_dispatch, - 
forward_mode=state.forward_batch.forward_mode, - ) - - def _forward_tbo_op_dispatch_a(self, state): - self._forward_deepep_dispatch_a( - self.tbo_deepep_dispatchers[state.tbo_subbatch_index], - state.forward_batch.forward_mode, - state.hidden_states_after_post_attn_ln, - state.router_logits, - ) - - def _forward_tbo_op_dispatch_b(self, state, tbo_child_index: int): - dispatcher = self.tbo_deepep_dispatchers[state.tbo_subbatch_index] - with expert_distribution_recorder.with_current_layer( - self.layer_id - ), expert_distribution_recorder.with_debug_name( - ["child_a", "child_b"][tbo_child_index] - ): + if self.ep_size > 1: + # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value ( - state.hidden_states_from_dispatch, - state.topk_idx_from_dispatch, - state.topk_weights_from_dispatch, - state.reorder_topk_ids_from_dispatch, - state.seg_indptr_from_dispatch, - state.masked_m_from_dispatch, - state.expected_m_from_dispatch, - ) = dispatcher.dispatch_b() - - def _forward_tbo_op_combine_a(self, state): - self.tbo_deepep_dispatchers[state.tbo_subbatch_index].combine_a( - hidden_states=state.expert_output_hidden_states, - topk_idx=state.topk_idx_from_dispatch, - topk_weights=state.topk_weights_from_dispatch, - forward_mode=state.forward_batch.forward_mode, + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) = self.deepep_dispatcher.dispatch( + hidden_states, + topk_idx, + topk_weights, + forward_mode=forward_mode, + ) + final_hidden_states = ( + self.experts( + hidden_states=hidden_states, + reorder_topk_ids=reorder_topk_ids, + seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, + forward_mode=forward_mode, + ) + * self.routed_scaling_factor ) + if self.ep_size > 1: + final_hidden_states = self.deepep_dispatcher.combine( + final_hidden_states, + topk_idx, + topk_weights, + forward_mode, + ) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output - def _forward_tbo_op_combine_b(self, state): - dispatcher = self.tbo_deepep_dispatchers[state.tbo_subbatch_index] - state.hidden_states_from_combine = dispatcher.combine_b() - - def _forward_tbo_op_shared(self, state): - state.shared_output = self._forward_deepep_shared_output( - state.forward_batch.forward_mode, state.hidden_states_after_post_attn_ln - ) + return final_hidden_states - def _forward_tbo_op_compute_layer_output(self, state): - output = dict( - positions=state.positions, - hidden_states=state.hidden_states_from_combine + state.shared_output, - forward_batch=state.forward_batch, - residual=state.residual_after_post_attn_ln, - tbo_subbatch_index=state.tbo_subbatch_index, - ) - state.clear() - return output + def _forward_shared_experts(self, hidden_states): + if self.n_shared_experts is not None and self.n_share_experts_fusion == 0: + return self.shared_experts(hidden_states) + else: + return None def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: @@ -549,7 +431,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("q_b_proj", prefix), - ) + ) else: self.q_proj = ColumnParallelLinear( self.hidden_size, @@ -559,7 +441,7 @@ def __init__( prefix=add_prefix("q_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -567,7 +449,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) + ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, 
eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, @@ -575,7 +457,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_b_proj", prefix), - ) + ) # O projection. self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, @@ -586,7 +468,7 @@ def __init__( reduce_results=reduce_results, tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, @@ -611,6 +493,7 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -659,8 +542,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -723,7 +606,7 @@ def __init__( prefix=add_prefix("q_b_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) else: self.q_proj = ColumnParallelLinear( self.hidden_size, @@ -733,7 +616,7 @@ def __init__( prefix=add_prefix("q_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -742,7 +625,7 @@ def __init__( prefix=add_prefix("kv_b_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) # O projection. self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, @@ -753,7 +636,7 @@ def __init__( prefix=add_prefix("o_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -761,7 +644,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) + ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) if rope_scaling: @@ -791,8 +674,9 @@ def __init__( num_kv_heads=1, layer_id=layer_id, v_head_dim=self.kv_lora_rank, + quant_config=quant_config, prefix=add_prefix("attn_mqa", prefix), - ) + ) self.attn_mha = RadixAttention( self.num_local_heads, @@ -801,14 +685,14 @@ def __init__( num_kv_heads=self.num_local_heads, layer_id=layer_id, v_head_dim=self.v_head_dim, + quant_config=quant_config, prefix=add_prefix("attn_mha", prefix), - ) + ) self.w_kc = None self.w_vc = None self.w_scale = None - self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"] self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] @@ -816,7 +700,7 @@ def __init__( self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" def no_absorb(self, forward_batch: ForwardBatch) -> bool: - if self.enable_flashinfer_mla: + if self.attention_backend == "flashinfer": # Flashinfer MLA: Do not absorb when enabling ragged prefill return ( not self.flashinfer_mla_disable_ragged @@ -913,17 +797,6 @@ def forward_absorb( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - state = self.forward_absorb_stage_prepare( - positions, hidden_states, forward_batch - ) - return self.forward_absorb_stage_core(state) - - def forward_absorb_stage_prepare( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ): q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( q_len, 
self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim @@ -943,7 +816,7 @@ def forward_absorb_stage_prepare( q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn @@ -966,11 +839,6 @@ def forward_absorb_stage_prepare( q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - return q_input, k_input, v_input, forward_batch - - def forward_absorb_stage_core(self, state) -> torch.Tensor: - q_input, k_input, v_input, forward_batch = state - attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -979,7 +847,7 @@ def forward_absorb_stage_core(self, state) -> torch.Tensor: attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn @@ -1026,7 +894,7 @@ def forward_absorb_fused_mla_rope( q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn @@ -1122,7 +990,7 @@ def forward_absorb_fused_mla_rope( attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn @@ -1142,17 +1010,6 @@ def forward_absorb_fused_mla_rope( return output -class _DecoderLayerExecutionMode(Enum): - MLP_ONE = auto() - MLP_ALL = auto() - - -@dataclass -class _DecoderLayerInfo: - is_sparse: bool - execution_mode: _DecoderLayerExecutionMode - - class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -1163,6 +1020,14 @@ def __init__( is_nextn: bool = False, prefix: str = "", ) -> None: + + def is_sparse_layer(l: int): + return ( + config.n_routed_experts is not None + and l >= config.first_k_dense_replace + and l % config.moe_layer_freq == 0 + ) + super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -1215,35 +1080,26 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn) - previous_layer_info = self._compute_info( - config, layer_id=layer_id - 1, is_nextn=False - ) - - if self.info.is_sparse: + if is_nextn or is_sparse_layer(layer_id): self.mlp = DeepseekV2MoE( config=config, quant_config=quant_config, prefix=add_prefix("mlp", prefix), - layer_id=self.layer_id, ) + self.is_sparse = True else: - if self._enable_moe_dense_fully_dp(): - mlp_tp_rank, mlp_tp_size = 0, 1 - else: - mlp_tp_rank, mlp_tp_size = None, None self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=add_prefix("mlp", prefix), - tp_rank=mlp_tp_rank, - tp_size=mlp_tp_size, ) + self.is_sparse = False self.input_is_scattered = ( - previous_layer_info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE + is_sparse_layer(layer_id - 1) + and 
global_server_args_dict["enable_deepep_moe"] ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -1252,25 +1108,6 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) - @staticmethod - def _enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 - - @staticmethod - def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): - is_sparse = is_nextn or ( - config.n_routed_experts is not None - and layer_id >= config.first_k_dense_replace - and layer_id % config.moe_layer_freq == 0 - ) - execution_mode = ( - _DecoderLayerExecutionMode.MLP_ONE - if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) - else _DecoderLayerExecutionMode.MLP_ALL - ) - return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) - def forward( self, positions: torch.Tensor, @@ -1278,18 +1115,16 @@ def forward( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ONE: - return self.forward_mode_mlp_one( + if global_server_args_dict["enable_deepep_moe"] and self.is_sparse: + return self.forward_deepep( positions, hidden_states, forward_batch, residual ) - elif self.info.execution_mode == _DecoderLayerExecutionMode.MLP_ALL: - return self.forward_mode_mlp_all( + else: + return self.forward_normal( positions, hidden_states, forward_batch, residual ) - else: - raise NotImplementedError - def forward_mode_mlp_all( + def forward_normal( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1300,9 +1135,11 @@ def forward_mode_mlp_all( if hidden_states.shape[0] == 0: residual = hidden_states else: - hidden_states, residual = self._forward_input_layernorm( - hidden_states, residual - ) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) assert not ( self.attn_tp_size != 1 and self.input_is_scattered @@ -1354,7 +1191,7 @@ def forward_mode_mlp_all( return hidden_states, residual - def forward_mode_mlp_one( + def forward_deepep( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1365,9 +1202,11 @@ def forward_mode_mlp_one( if hidden_states.shape[0] == 0: residual = hidden_states else: - hidden_states, residual = self._forward_input_layernorm( - hidden_states, residual - ) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) if self.attn_tp_size != 1 and self.input_is_scattered: hidden_states, local_hidden_states = ( @@ -1408,13 +1247,7 @@ def forward_mode_mlp_one( hidden_states, residual = self.post_attention_layernorm( hidden_states, residual ) - - if not ( - self._enable_moe_dense_fully_dp() - and (not self.info.is_sparse) - and hidden_states.shape[0] == 0 - ): - hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) if self.is_last_layer and self.attn_tp_size != 1: hidden_states += residual @@ -1429,117 +1262,6 @@ def forward_mode_mlp_one( return hidden_states, residual - def _forward_input_layernorm(self, hidden_states, residual): - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = 
self.input_layernorm(hidden_states, residual) - return hidden_states, residual - - # ----------------------------------------- TBO-related -------------------------------------------- - - def get_forward_tbo_operations( - self, forward_mode: ForwardMode, tbo_child_index: int - ): - if forward_mode == ForwardMode.EXTEND: - operations = [ - self._forward_tbo_op_input_layernorm, - self._forward_tbo_op_prefill_attn, - self._forward_tbo_op_post_attn_layernorm, - self.mlp._forward_tbo_op_gate, - self.mlp._forward_tbo_op_dispatch_a, - two_batch_overlap.YieldOperation(), - partial( - self.mlp._forward_tbo_op_dispatch_b, tbo_child_index=tbo_child_index - ), - self.mlp._forward_tbo_op_mlp, - self.mlp._forward_tbo_op_combine_a, - two_batch_overlap.YieldOperation(), - self.mlp._forward_tbo_op_shared, - self.mlp._forward_tbo_op_combine_b, - self.mlp._forward_tbo_op_compute_layer_output, - ] - elif forward_mode == ForwardMode.DECODE: - operations = [ - self._forward_tbo_op_input_layernorm, - self._forward_tbo_op_decode_attn_0, - two_batch_overlap.YieldOperation(), - self._forward_tbo_op_decode_attn_1, - self._forward_tbo_op_post_attn_layernorm, - self.mlp._forward_tbo_op_gate, - two_batch_overlap.YieldOperation(), - self.mlp._forward_tbo_op_dispatch_a, - self.mlp._forward_tbo_op_shared, - two_batch_overlap.YieldOperation(), - partial( - self.mlp._forward_tbo_op_dispatch_b, tbo_child_index=tbo_child_index - ), - self.mlp._forward_tbo_op_mlp, - self.mlp._forward_tbo_op_combine_a, - two_batch_overlap.YieldOperation(), - self.mlp._forward_tbo_op_combine_b, - self.mlp._forward_tbo_op_compute_layer_output, - two_batch_overlap.YieldOperation(), - ] - else: - raise NotImplementedError(f"Unsupported {forward_mode=}") - return two_batch_overlap.decorate_operations( - operations, debug_name_prefix=f"L{self.layer_id}-" - ) - - def _forward_tbo_op_input_layernorm( - self, - state, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], - tbo_subbatch_index: int, - ): - state.hidden_states_after_input_ln, state.residual_after_input_ln = ( - self._forward_input_layernorm(hidden_states, residual) - ) - state.update( - dict( - forward_batch=forward_batch, - positions=positions, - tbo_subbatch_index=tbo_subbatch_index, - ) - ) - - def _forward_tbo_op_prefill_attn(self, state): - state.hidden_states_after_attn = self.self_attn( - positions=state.positions, - hidden_states=state.hidden_states_after_input_ln, - forward_batch=state.forward_batch, - ) - - def _forward_tbo_op_decode_attn_0(self, state): - state.self_attn_state = self.self_attn.forward_absorb_stage_prepare( - positions=state.positions, - hidden_states=state.hidden_states_after_input_ln, - forward_batch=state.forward_batch, - ) - - def _forward_tbo_op_decode_attn_1(self, state): - assert ( - (get_tensor_model_parallel_world_size() > 1) - and global_server_args_dict["enable_dp_attention"] - and global_server_args_dict["enable_deepep_moe"] - and isinstance(self.mlp, DeepseekV2MoE) - ) - state.hidden_states_after_attn = self.self_attn.forward_absorb_stage_core( - state.self_attn_state - ) - - def _forward_tbo_op_post_attn_layernorm(self, state): - state.hidden_states_after_post_attn_ln, state.residual_after_post_attn_ln = ( - self.post_attention_layernorm( - state.hidden_states_after_attn, state.residual_after_input_ln - ) - ) - class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False @@ -1553,7 +1275,6 @@ def __init__( super().__init__() self.padding_id = config.pad_token_id 
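# [editor's note] The with_current_layer(...) context calls removed just
# below (forward reverts to set_current_layer) are backed by the Withable
# helper that an earlier patch in this series adds to sglang/srt/utils.py: a
# container whose value exists only inside a with-block. Minimal usage
# sketch (the Recorder wrapper is illustrative, not the sglang recorder):

from sglang.srt.utils import Withable

class Recorder:
    def __init__(self):
        self._current_layer = Withable()

    def with_current_layer(self, layer_id: int):
        # Delegates to Withable.with_value; re-entering without exiting asserts.
        return self._current_layer.with_value(layer_id)

recorder = Recorder()
with recorder.with_current_layer(3):
    assert recorder._current_layer.value == 3  # set only inside the block
assert recorder._current_layer.value is None   # cleared on exit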
self.vocab_size = config.vocab_size - self.first_k_dense_replace = config.first_k_dense_replace self.embed_tokens = VocabParallelEmbedding( config.vocab_size, @@ -1589,27 +1310,12 @@ def forward( hidden_states = input_embeds residual = None - - normal_num_layers = ( - self.first_k_dense_replace - if forward_batch.can_run_tbo - else len(self.layers) - ) - for i in range(normal_num_layers): - with expert_distribution_recorder.with_current_layer(i): - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual - ) - - hidden_states, residual = self._forward_tbo_layers( - positions=positions, - forward_batch=forward_batch, - hidden_states=hidden_states, - residual=residual, - start_layer=normal_num_layers, - ) - + for i in range(len(self.layers)): + expert_distribution_recorder.set_current_layer(i) + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) if not forward_batch.forward_mode.is_idle(): if residual is None: hidden_states = self.norm(hidden_states) @@ -1617,47 +1323,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def _forward_tbo_layers( - self, - positions: torch.Tensor, - forward_batch: ForwardBatch, - hidden_states: torch.Tensor, - residual: torch.Tensor, - start_layer: int, - ): - end_layer = len(self.layers) - if start_layer == end_layer: - return hidden_states, residual - - def compute_operations(tbo_child_index: str): - return [ - op - for i in range(start_layer, end_layer) - for op in self.layers[i].get_forward_tbo_operations( - forward_batch.forward_mode, tbo_child_index - ) - ] - - # TODO do not hardcode - chosen_num_sms = ( - torch.cuda.get_device_properties(device="cuda").multi_processor_count - 20 - ) - with configure_deep_gemm_num_sms(num_sms=chosen_num_sms): - return two_batch_overlap.model_forward_execute_two_batch( - inputs=dict( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - residual=residual, - ), - operations_a=compute_operations(0), - operations_b=compute_operations(1), - delta_stages={ - ForwardMode.EXTEND: 0, - ForwardMode.DECODE: 2, - }[forward_batch.forward_mode], - ) - class DeepseekV2ForCausalLM(nn.Module): @@ -1669,7 +1334,28 @@ def __init__( ) -> None: super().__init__() self.config = config + self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config + self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] + # Only Deepseek V3/R1 can use shared experts fusion optimization now. + if ( + global_server_args_dict.get("disable_shared_experts_fusion", False) + or self.config.architectures[0] != "DeepseekV3ForCausalLM" + or self.config.n_routed_experts != 256 + or self.config.routed_scaling_factor != 2.5 + ): + self.n_share_experts_fusion = None + global_server_args_dict["n_share_experts_fusion"] = None + logger.info( + "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled." + ) + elif self.n_share_experts_fusion is None: + global_server_args_dict["n_share_experts_fusion"] = self.tp_size + self.n_share_experts_fusion = self.tp_size + logger.info( + f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion." 
+ ) + self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) @@ -1700,12 +1386,134 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch ) + def post_load_weights(self): + + # Perform post-processing after loading weights + + if not global_server_args_dict["disable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + if hasattr(self.quant_config, "weight_block_size"): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0: + weights_list = list(weights) + weights_dict = dict(weights_list) + suffix_list = [ + "down_proj.weight", + "down_proj.weight_scale_inv", + "gate_proj.weight", + "gate_proj.weight_scale_inv", + "up_proj.weight", + "up_proj.weight_scale_inv", + ] + names_to_remove = [] + for moe_layer in tqdm( + range( + self.config.first_k_dense_replace, + self.config.num_hidden_layers, + self.config.moe_layer_freq, + ), + desc=f"Cloning {self.n_share_experts_fusion} " + "replicas of the 
shared expert into MoE", + ): + for num_repeat in range(self.n_share_experts_fusion): + for suffix in suffix_list: + shared_expert_weight_name = ( + f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}" + ) + weights_list.append( + ( + f"model.layers.{moe_layer}." + f"mlp.experts." + f"{self.config.n_routed_experts + num_repeat}" + f".{suffix}", + weights_dict[shared_expert_weight_name].clone(), + ) + ) + names_to_remove += [shared_expert_weight_name] + weights = [w for w in weights_list if w[0] not in names_to_remove] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -1718,7 +1526,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, + num_experts=self.config.n_routed_experts + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) @@ -1782,79 +1595,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - if not global_server_args_dict["disable_mla"]: - for layer_id in range(self.config.num_hidden_layers): - self_attn = self.model.layers[layer_id].self_attn - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = ops.awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T - else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. 
- if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 - ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + self.post_load_weights() def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight @@ -1867,12 +1608,6 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() - def get_expert_location_metadata(self): - return ExpertLocationMetadata.init_new( - num_layers=self.config.num_hidden_layers, - num_logical_experts=self.config.n_routed_experts, - ) - class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass From a8ebfe9a82cfb1950885b6582b04ecad934e0c4a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:02:55 +0800 Subject: [PATCH 0616/1089] cp --- python/sglang/srt/models/deepseek_v2.py | 85 ++++++++++++------------- 1 file changed, 41 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e7f65b52144..4a1698db991 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,10 +22,6 @@ import torch import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -73,6 +69,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -115,7 +114,7 @@ def __init__( prefix=add_prefix("gate_up_proj", prefix), tp_rank=tp_rank, tp_size=tp_size, - ) + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -283,10 +282,8 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = 
self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - final_hidden_states = ( - self.experts(hidden_states=hidden_states, router_logits=router_logits) - * self.routed_scaling_factor - ) + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + final_hidden_states *= self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: @@ -412,7 +409,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -431,7 +428,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("q_b_proj", prefix), - ) + ) else: self.q_proj = ColumnParallelLinear( self.hidden_size, @@ -441,7 +438,7 @@ def __init__( prefix=add_prefix("q_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -449,7 +446,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) + ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, @@ -457,7 +454,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_b_proj", prefix), - ) + ) # O projection. self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, @@ -468,7 +465,7 @@ def __init__( reduce_results=reduce_results, tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, @@ -525,12 +522,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -584,7 +581,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -606,7 +603,7 @@ def __init__( prefix=add_prefix("q_b_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) else: self.q_proj = ColumnParallelLinear( self.hidden_size, @@ -616,7 +613,7 @@ def __init__( prefix=add_prefix("q_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -625,7 +622,7 @@ def __init__( prefix=add_prefix("kv_b_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) # O projection. 
self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, @@ -636,7 +633,7 @@ def __init__( prefix=add_prefix("o_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, - ) + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -644,7 +641,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) + ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) if rope_scaling: @@ -676,7 +673,7 @@ def __init__( v_head_dim=self.kv_lora_rank, quant_config=quant_config, prefix=add_prefix("attn_mqa", prefix), - ) + ) self.attn_mha = RadixAttention( self.num_local_heads, @@ -687,7 +684,7 @@ def __init__( v_head_dim=self.v_head_dim, quant_config=quant_config, prefix=add_prefix("attn_mha", prefix), - ) + ) self.w_kc = None self.w_vc = None @@ -771,16 +768,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -816,7 +813,7 @@ def forward_absorb( q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn @@ -833,11 +830,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -847,7 +844,7 @@ def forward_absorb( attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn @@ -894,7 +891,7 @@ def forward_absorb_fused_mla_rope( q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn @@ -913,15 +910,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., 
self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -978,7 +975,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -990,7 +987,7 @@ def forward_absorb_fused_mla_rope( attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, - ) + ) elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn From b26010d15af0aaebee722fd6db85edf3a8c36c96 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:03:12 +0800 Subject: [PATCH 0617/1089] fmt --- python/sglang/srt/models/deepseek_v2.py | 65 +++++++++++++------------ 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 4a1698db991..8ddc24d2b37 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,6 +22,10 @@ import torch import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, parallel_state, @@ -69,9 +73,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -282,7 +283,9 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) final_hidden_states *= self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output @@ -409,7 +412,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -522,12 +525,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., 
self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -539,8 +542,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -581,7 +584,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -768,16 +771,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -830,11 +833,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -910,15 +913,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. 
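A note for readers following the MLA changes in these hunks: the forward_absorb path relies on the weight-absorption trick. Rather than materializing full per-head keys and values, the no-PE half of q is projected into the compressed kv_lora_rank space with w_kc (sliced out of kv_b_proj at load time, as the removed post-load code earlier in this series shows), attention runs directly over the latent cache via attn_mqa, and w_vc maps the result back to the per-head value dimension. Below is a minimal shape-level sketch of the two bmm calls; the dimensions are illustrative only, and the real path additionally casts to bfloat16 and applies the per-tensor w_scale (or takes the fp8 bmm_fp8 branch) as the diff shows.

import torch

# Illustrative dimensions; real values come from the model config.
num_tokens, num_heads = 8, 16
qk_nope_head_dim, kv_lora_rank, v_head_dim = 128, 512, 128

q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
w_vc = torch.randn(num_heads, kv_lora_rank, v_head_dim)

# Absorb step 1: project the no-PE query into the compressed KV space,
# so attention can run over the kv_lora_rank latent cache (MQA-style).
q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)  # (heads, tokens, kv_lora_rank)

# Stand-in for the attn_mqa output computed against the latent cache.
attn_output = torch.randn(num_tokens, num_heads, kv_lora_rank)

# Absorb step 2: map the latent attention output back to the per-head value dim.
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), w_vc)  # (heads, tokens, v_head_dim)
output = attn_bmm_output.transpose(0, 1).reshape(num_tokens, num_heads * v_head_dim)
print(output.shape)  # torch.Size([8, 2048])

The bfloat16 branch in the diff corresponds directly to these two bmm calls; the fp8 branch quantizes the activation per tensor first, which is why kv_b_proj is requantized to a per-tensor scale at load time (the NOTE about bmm_fp8 in the removed block).
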
@@ -975,7 +978,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1493,7 +1496,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1524,11 +1527,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From 65c08f488317ebe89b64c4c5189d56cf63876b7d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:04:02 +0800 Subject: [PATCH 0618/1089] more --- python/sglang/srt/utils.py | 40 +++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 86d1a709d9b..4296c703742 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -54,10 +54,10 @@ import torch.distributed as dist import triton import zmq +from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version -from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -914,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 + total_mem = mem.total / 1024 ** 3 + available_mem = mem.available / 1024 ** 3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) + buf_size = int(0.5 * 1024 ** 3) else: buf_size = -1 @@ -1437,10 +1437,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1448,10 +1448,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1641,7 +1641,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2 :] + port_str = addr[end + 2:] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1779,7 +1779,7 @@ def retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2**try_index), max_delay) * ( + delay = min(initial_delay * (2 ** try_index), max_delay) * ( 
0.75 + 0.25 * random.random() ) @@ -1864,15 +1864,19 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) -class DisposibleBox: - def __init__(self, value): +class DisposibleTensor: + def __init__(self, value: torch.Tensor): self._value = value @property def value(self): - assert self._value is not None + assert not self.is_disposed return self._value def dispose(self): - assert self._value is not None + assert not self.is_disposed self._value = None + + @property + def is_disposed(self): + return self._value is None From c24c212cefa399fb5c85521367af4bd8a0419337 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:04:39 +0800 Subject: [PATCH 0619/1089] more --- python/sglang/srt/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 4296c703742..df7ae77a708 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1880,3 +1880,9 @@ def dispose(self): @property def is_disposed(self): return self._value is None + + @staticmethod + def maybe_unwrap(value: Union[torch.Tensor, "DisposibleTensor"]) -> torch.Tensor: + if isinstance(value, DisposibleTensor): + return value.value + return value From 8dccb009c098cdd715f2b54a0d531f6eb2be1f41 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:05:29 +0800 Subject: [PATCH 0620/1089] more --- python/sglang/srt/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index df7ae77a708..5dca0898855 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1886,3 +1886,14 @@ def maybe_unwrap(value: Union[torch.Tensor, "DisposibleTensor"]) -> torch.Tensor if isinstance(value, DisposibleTensor): return value.value return value + + +class TensorCreator: + def __init__(self, creator: Callable[[], torch.Tensor]): + self._creator = creator + + def create(self): + assert self._creator is not None + value = self._creator() + self._creator = None + return value From 2a4c6301144c7acaf580634b128a0a06d0eef330 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:05:59 +0800 Subject: [PATCH 0621/1089] more --- python/sglang/srt/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5dca0898855..e209b8fb36f 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1897,3 +1897,9 @@ def create(self): value = self._creator() self._creator = None return value + + @staticmethod + def maybe_create(value: Union[torch.Tensor, "TensorCreator"]) -> torch.Tensor: + if isinstance(value, TensorCreator): + return value.create() + return value From 86f1c995eaa8aa0bd12f311de82abfe0422bc623 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:06:38 +0800 Subject: [PATCH 0622/1089] more --- .../sglang/srt/layers/moe/ep_moe/kernels.py | 4 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 238 +++++++++--------- .../srt/layers/moe/ep_moe/token_dispatcher.py | 193 +++++++------- 3 files changed, 217 insertions(+), 218 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 87b9bea52bf..3f2851105c9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -7,7 +7,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from 
sglang.srt.utils import DisposibleBox, is_cuda +from sglang.srt.utils import DisposibleTensor, is_cuda _is_cuda = is_cuda() if _is_cuda: @@ -655,7 +655,7 @@ def grouped_gemm_triton( assert len(block_shape) == 2 block_n, block_k = block_shape[0], block_shape[1] if _is_cuda: - if isinstance(a, DisposibleBox): + if isinstance(a, DisposibleTensor): a_box = a a, scale_a = sglang_per_token_group_quant_fp8(a.value, block_k) a_box.dispose() diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 44df6b665eb..65b4618911d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -40,7 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.utils import ( DeepEPMode, - DisposibleBox, + DisposibleTensor, is_cuda, is_hip, set_weight_attrs, @@ -81,18 +81,18 @@ def _init_flashinfer_wrapper(cls, device): # c = a * b def forward( - self, - a: torch.Tensor, - b: torch.Tensor, - c: torch.Tensor, - batch_size: int, - weight_column_major: bool, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - scale_a: torch.Tensor = None, - scale_b: torch.Tensor = None, - block_shape: Optional[List[int]] = None, + self, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): if self.use_flashinfer: # TODO: flashinfer @@ -132,22 +132,22 @@ class EPMoE(torch.nn.Module): """ def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - activation: str = "silu", + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", ): super().__init__() @@ -264,7 +264,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -382,11 +382,11 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, ) -> 
List[Tuple[str, str, int, str]]: return [ # (param_name, weight_name, expert_id, shard_id) @@ -409,12 +409,12 @@ def make_expert_params_mapping( ] def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: if expert_id < self.start_expert_id or expert_id > self.end_expert_id: return @@ -441,25 +441,25 @@ def weight_loader( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") def _load_fp8_scale( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: param_data = param.data # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: if ( - param_data[expert_id] != 1 - and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 ): raise ValueError( "input_scales of w1 and w3 of a layer " @@ -473,11 +473,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight @@ -497,13 +497,13 @@ def _load_fp8_scale( class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights( - self, - layer: torch.nn.Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, ): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( @@ -562,16 +562,16 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs) def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: raise NotImplementedError @@ -590,13 +590,13 @@ def __init__(self, quant_config: Fp8Config): self.block_quant = self.quant_config.weight_block_size is not None def create_weights( - self, - layer: Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - 
**extra_weight_attrs, + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, ): if self.quant_config.is_checkpoint_fp8_serialized: @@ -781,16 +781,16 @@ def process_weights_after_loading(self, layer: Module) -> None: return def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: raise NotImplementedError @@ -803,23 +803,23 @@ class DeepEPMoE(EPMoE): _has_printed = False def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - activation: str = "silu", - deepep_mode: DeepEPMode = DeepEPMode.auto, + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", + deepep_mode: DeepEPMode = DeepEPMode.auto, ): super().__init__( num_experts, @@ -855,13 +855,13 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - reorder_topk_ids: torch.Tensor, - seg_indptr: torch.Tensor, - masked_m: torch.Tensor, - expected_m: int, - forward_mode: ForwardMode, + self, + hidden_states: torch.Tensor, + reorder_topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + forward_mode: ForwardMode, ): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: @@ -872,10 +872,10 @@ def forward( raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") def forward_normal( - self, - hidden_states: DisposibleBox, - reorder_topk_ids: torch.Tensor, - seg_indptr: torch.Tensor, + self, + hidden_states: DisposibleTensor, + reorder_topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, ): assert self.quant_method is not None assert self.activation == "silu" @@ -997,15 +997,15 @@ def forward_normal( return down_output def forward_deepgemm_masked( - self, - hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], - masked_m: torch.Tensor, - expected_m: int, + self, + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], + masked_m: torch.Tensor, + expected_m: int, ): assert self.quant_method is not None assert self.activation == "silu" assert ( - hidden_states_fp8[0].size(0) % 4 == 0 + hidden_states_fp8[0].size(0) % 4 == 0 ), f"TMA 
alignment error: {hidden_states_fp8[0].size(0)}" # GroupGemm-0 diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 3880dca2c3d..9345f92e271 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,4 @@ -from sglang.srt.utils import DeepEPMode, DisposibleBox +from sglang.srt.utils import DeepEPMode, DisposibleTensor try: from deep_ep import Buffer @@ -27,7 +27,6 @@ class DeepEPDispatchMode(IntEnum): class DeepEPBuffer: - _buffer = None _dispatch_mode: Optional[DeepEPDispatchMode] = None _hidden_size: Optional[int] = None @@ -36,13 +35,13 @@ class DeepEPBuffer: @classmethod def get_deepep_buffer( - cls, - group: dist.ProcessGroup, - hidden_size: int, - param_bytes: int, - deepep_mode: DeepEPMode, - num_max_dispatch_tokens_per_rank: int = None, - num_experts: int = None, + cls, + group: dist.ProcessGroup, + hidden_size: int, + param_bytes: int, + deepep_mode: DeepEPMode, + num_max_dispatch_tokens_per_rank: int = None, + num_experts: int = None, ): if cls._buffer is not None: return cls._buffer @@ -55,8 +54,8 @@ def get_deepep_buffer( if deepep_mode.enable_normal(): hidden_bytes = hidden_size * param_bytes for config in ( - Buffer.get_dispatch_config(group.size()), - Buffer.get_combine_config(group.size()), + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), ): num_nvl_bytes = max( config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), @@ -113,15 +112,15 @@ def set_dispatch_mode_as_low_latency(cls): class _DeepEPDispatcherImplBase: def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool, - num_experts: int, - num_local_experts: int, - hidden_size: int, - params_dtype: torch.dtype, - deepep_mode: DeepEPMode, + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, + deepep_mode: DeepEPMode, ): if not use_deepep: raise ImportError( @@ -144,10 +143,10 @@ def __init__( self.handle = None def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): raise NotImplementedError @@ -155,10 +154,10 @@ def dispatch_b(self, *args, **kwargs): raise NotImplementedError def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): raise NotImplementedError @@ -177,10 +176,10 @@ def __init__(self, async_finish: bool, **kwargs): self.src2dst = None def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): topk_idx = topk_idx.to(torch.int64) previous_event = Buffer.capture() if self.async_finish else None @@ -205,7 +204,7 @@ def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event): seg_indptr = torch.zeros( (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64 ) - hidden_states = DisposibleBox(hidden_states) + hidden_states = DisposibleTensor(hidden_states) masked_m = expected_m = None @@ -220,11 +219,11 @@ def dispatch_b(self, hidden_states, topk_idx, 
topk_weights, previous_event): ) def _dispatch_core( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - previous_event, + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + previous_event, ): buffer = self._get_buffer() ( @@ -273,12 +272,12 @@ def _dispatch_core( ) def _deepep_permute( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - fp8_dtype: Optional[torch.dtype] = None, - use_fp8_w8a8: bool = False, - use_block_quant: bool = False, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, ): """ Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher @@ -309,13 +308,13 @@ def _deepep_permute( hidden_states.shape[1], BLOCK_SIZE=512, ) - return reorder_topk_ids, seg_indptr, DisposibleBox(gateup_input) + return reorder_topk_ids, seg_indptr, DisposibleTensor(gateup_input) def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk @@ -384,17 +383,17 @@ def __init__(self, return_recv_hook: bool, **kwargs): self.return_recv_hook = return_recv_hook def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): buffer = self._get_buffer() topk_idx = topk_idx.to(torch.int64) expected_m = ( - hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] - + self.num_experts - ) // self.num_experts + hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + + self.num_experts + ) // self.num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, topk_idx, @@ -411,14 +410,14 @@ def dispatch_a( ) def dispatch_b( - self, - hidden_states, - topk_idx, - topk_weights, - masked_m, - expected_m, - event, - hook, + self, + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + event, + hook, ): hook() if self.return_recv_hook else event.current_stream_wait() @@ -435,10 +434,10 @@ def dispatch_b( ) def _dispatch_core( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - use_fp8: bool = False, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + use_fp8: bool = False, ): """ # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'. 
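One pattern worth calling out across these dispatcher patches: every communication op is split into paired _a/_b halves (dispatch_a/dispatch_b, combine_a/combine_b). The _a half launches the all-to-all (async, or returning a receive hook in low-latency mode) and stashes the intermediate state; the _b half waits on that state and returns the results, so a caller can overlap other GPU work between the two calls. The following is a dependency-free sketch of that launch/wait split, with illustrative names rather than the sglang API.

from typing import Any, Optional


class TwoPhaseOp:
    """Launch/wait split: run_a() starts async work, run_b() completes it."""

    def __init__(self) -> None:
        self._inner_state: Optional[Any] = None

    def run_a(self, payload: int) -> None:
        # Phase A: kick off the (here: simulated) async operation and
        # remember its handle so phase B can finish it later.
        self._inner_state = ("pending", payload)

    def run_b(self) -> int:
        # Phase B: wait on the handle (a real implementation would call
        # event.current_stream_wait() or invoke the receive hook) and
        # hand back the result.
        assert self._inner_state is not None, "run_a() must be called first"
        _, payload = self._inner_state
        self._inner_state = None
        return payload * 2


op = TwoPhaseOp()
op.run_a(21)
# ... unrelated compute can overlap with the in-flight communication here ...
assert op.run_b() == 42

DeepEPDispatcher.dispatch() and combine() later in this patch are exactly the degenerate case of this pattern: they invoke the _a half and then the _b half back to back when no overlapping work is scheduled.
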
@@ -487,10 +486,10 @@ def _dispatch_core( return packed_recv_hidden, packed_recv_count, event, hook def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): hidden_states, event, hook = self._combine_core( hidden_states, @@ -504,10 +503,10 @@ def combine_b(self, hidden_states, event, hook): return hidden_states def _combine_core( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): buffer = self._get_buffer() combined_hidden_states, event, hook = buffer.low_latency_combine( @@ -535,17 +534,17 @@ def _get_buffer(self): class DeepEPDispatcher: def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool = False, - num_experts: int = None, - num_local_experts: int = None, - hidden_size: int = None, - params_dtype: torch.dtype = None, - deepep_mode: DeepEPMode = DeepEPMode.auto, - async_finish: bool = False, - return_recv_hook: bool = False, + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.auto, + async_finish: bool = False, + return_recv_hook: bool = False, ): self.deepep_mode = deepep_mode @@ -576,11 +575,11 @@ def dispatch(self, *args, **kwargs) -> Tuple: return self.dispatch_b() def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_mode: ForwardMode = None, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode = None, ): inner_state = self._get_impl(forward_mode).dispatch_a( hidden_states=hidden_states, @@ -599,11 +598,11 @@ def combine(self, *args, **kwargs) -> Tuple: return self.combine_b() def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_mode: ForwardMode, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, ): inner_state = self._get_impl(forward_mode).combine_a( hidden_states=hidden_states, From a9866309faed0d9281593b4026418d14e68eab0f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:06:53 +0800 Subject: [PATCH 0623/1089] fmt --- python/sglang/srt/layers/moe/ep_moe/layer.py | 236 +++++++++--------- .../srt/layers/moe/ep_moe/token_dispatcher.py | 186 +++++++------- python/sglang/srt/utils.py | 28 +-- 3 files changed, 225 insertions(+), 225 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 65b4618911d..24d2fbce861 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -81,18 +81,18 @@ def _init_flashinfer_wrapper(cls, device): # c = a * b def forward( - self, - a: torch.Tensor, - b: torch.Tensor, - c: torch.Tensor, - batch_size: int, - weight_column_major: bool, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - scale_a: torch.Tensor = None, - scale_b: torch.Tensor = None, - block_shape: Optional[List[int]] = None, + self, + a: torch.Tensor, + b: 
torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): if self.use_flashinfer: # TODO: flashinfer @@ -132,22 +132,22 @@ class EPMoE(torch.nn.Module): """ def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - activation: str = "silu", + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", ): super().__init__() @@ -264,7 +264,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -382,11 +382,11 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, ) -> List[Tuple[str, str, int, str]]: return [ # (param_name, weight_name, expert_id, shard_id) @@ -409,12 +409,12 @@ def make_expert_params_mapping( ] def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: if expert_id < self.start_expert_id or expert_id > self.end_expert_id: return @@ -441,25 +441,25 @@ def weight_loader( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") def _load_fp8_scale( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: param_data = param.data # Input scales can be loaded directly and should be equal. 
if "input_scale" in weight_name: if ( - param_data[expert_id] != 1 - and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 ): raise ValueError( "input_scales of w1 and w3 of a layer " @@ -473,11 +473,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight @@ -497,13 +497,13 @@ def _load_fp8_scale( class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights( - self, - layer: torch.nn.Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, ): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( @@ -562,16 +562,16 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs) def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: raise NotImplementedError @@ -590,13 +590,13 @@ def __init__(self, quant_config: Fp8Config): self.block_quant = self.quant_config.weight_block_size is not None def create_weights( - self, - layer: Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, ): if self.quant_config.is_checkpoint_fp8_serialized: @@ -781,16 +781,16 @@ def process_weights_after_loading(self, layer: Module) -> None: return def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: raise NotImplementedError @@ -803,23 +803,23 @@ class DeepEPMoE(EPMoE): _has_printed = False def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: 
Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - activation: str = "silu", - deepep_mode: DeepEPMode = DeepEPMode.auto, + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", + deepep_mode: DeepEPMode = DeepEPMode.auto, ): super().__init__( num_experts, @@ -855,13 +855,13 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - reorder_topk_ids: torch.Tensor, - seg_indptr: torch.Tensor, - masked_m: torch.Tensor, - expected_m: int, - forward_mode: ForwardMode, + self, + hidden_states: torch.Tensor, + reorder_topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + forward_mode: ForwardMode, ): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: @@ -872,10 +872,10 @@ def forward( raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") def forward_normal( - self, - hidden_states: DisposibleTensor, - reorder_topk_ids: torch.Tensor, - seg_indptr: torch.Tensor, + self, + hidden_states: DisposibleTensor, + reorder_topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, ): assert self.quant_method is not None assert self.activation == "silu" @@ -997,15 +997,15 @@ def forward_normal( return down_output def forward_deepgemm_masked( - self, - hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], - masked_m: torch.Tensor, - expected_m: int, + self, + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], + masked_m: torch.Tensor, + expected_m: int, ): assert self.quant_method is not None assert self.activation == "silu" assert ( - hidden_states_fp8[0].size(0) % 4 == 0 + hidden_states_fp8[0].size(0) % 4 == 0 ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}" # GroupGemm-0 diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 9345f92e271..3aad2ee1978 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -35,13 +35,13 @@ class DeepEPBuffer: @classmethod def get_deepep_buffer( - cls, - group: dist.ProcessGroup, - hidden_size: int, - param_bytes: int, - deepep_mode: DeepEPMode, - num_max_dispatch_tokens_per_rank: int = None, - num_experts: int = None, + cls, + group: dist.ProcessGroup, + hidden_size: int, + param_bytes: int, + deepep_mode: DeepEPMode, + num_max_dispatch_tokens_per_rank: int = None, + num_experts: int = None, ): if cls._buffer is not None: return cls._buffer @@ -54,8 +54,8 @@ def get_deepep_buffer( if deepep_mode.enable_normal(): hidden_bytes = hidden_size * param_bytes for config in ( - Buffer.get_dispatch_config(group.size()), - Buffer.get_combine_config(group.size()), + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), ): num_nvl_bytes = max( config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), @@ -112,15 
+112,15 @@ def set_dispatch_mode_as_low_latency(cls): class _DeepEPDispatcherImplBase: def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool, - num_experts: int, - num_local_experts: int, - hidden_size: int, - params_dtype: torch.dtype, - deepep_mode: DeepEPMode, + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, + deepep_mode: DeepEPMode, ): if not use_deepep: raise ImportError( @@ -143,10 +143,10 @@ def __init__( self.handle = None def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): raise NotImplementedError @@ -154,10 +154,10 @@ def dispatch_b(self, *args, **kwargs): raise NotImplementedError def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): raise NotImplementedError @@ -176,10 +176,10 @@ def __init__(self, async_finish: bool, **kwargs): self.src2dst = None def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): topk_idx = topk_idx.to(torch.int64) previous_event = Buffer.capture() if self.async_finish else None @@ -219,11 +219,11 @@ def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event): ) def _dispatch_core( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - previous_event, + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + previous_event, ): buffer = self._get_buffer() ( @@ -272,12 +272,12 @@ def _dispatch_core( ) def _deepep_permute( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - fp8_dtype: Optional[torch.dtype] = None, - use_fp8_w8a8: bool = False, - use_block_quant: bool = False, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, ): """ Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher @@ -311,10 +311,10 @@ def _deepep_permute( return reorder_topk_ids, seg_indptr, DisposibleTensor(gateup_input) def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk @@ -383,17 +383,17 @@ def __init__(self, return_recv_hook: bool, **kwargs): self.return_recv_hook = return_recv_hook def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): buffer = self._get_buffer() topk_idx = topk_idx.to(torch.int64) expected_m = ( - hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] - + self.num_experts - ) // self.num_experts + hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + + self.num_experts + ) // self.num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, topk_idx, @@ -410,14 +410,14 
@@ def dispatch_a( ) def dispatch_b( - self, - hidden_states, - topk_idx, - topk_weights, - masked_m, - expected_m, - event, - hook, + self, + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + event, + hook, ): hook() if self.return_recv_hook else event.current_stream_wait() @@ -434,10 +434,10 @@ def dispatch_b( ) def _dispatch_core( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - use_fp8: bool = False, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + use_fp8: bool = False, ): """ # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'. @@ -486,10 +486,10 @@ def _dispatch_core( return packed_recv_hidden, packed_recv_count, event, hook def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): hidden_states, event, hook = self._combine_core( hidden_states, @@ -503,10 +503,10 @@ def combine_b(self, hidden_states, event, hook): return hidden_states def _combine_core( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, ): buffer = self._get_buffer() combined_hidden_states, event, hook = buffer.low_latency_combine( @@ -534,17 +534,17 @@ def _get_buffer(self): class DeepEPDispatcher: def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool = False, - num_experts: int = None, - num_local_experts: int = None, - hidden_size: int = None, - params_dtype: torch.dtype = None, - deepep_mode: DeepEPMode = DeepEPMode.auto, - async_finish: bool = False, - return_recv_hook: bool = False, + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.auto, + async_finish: bool = False, + return_recv_hook: bool = False, ): self.deepep_mode = deepep_mode @@ -575,11 +575,11 @@ def dispatch(self, *args, **kwargs) -> Tuple: return self.dispatch_b() def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_mode: ForwardMode = None, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode = None, ): inner_state = self._get_impl(forward_mode).dispatch_a( hidden_states=hidden_states, @@ -598,11 +598,11 @@ def combine(self, *args, **kwargs) -> Tuple: return self.combine_b() def combine_a( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_mode: ForwardMode, + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, ): inner_state = self._get_impl(forward_mode).combine_a( hidden_states=hidden_states, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e209b8fb36f..1671aef5bcf 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -54,10 +54,10 @@ import torch.distributed as dist import triton import zmq -from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from PIL import Image from 
starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -914,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024 ** 3 - available_mem = mem.available / 1024 ** 3 + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024 ** 3) + buf_size = int(0.5 * 1024**3) else: buf_size = -1 @@ -1437,10 +1437,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1448,10 +1448,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1641,7 +1641,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2:] + port_str = addr[end + 2 :] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1779,7 +1779,7 @@ def retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2 ** try_index), max_delay) * ( + delay = min(initial_delay * (2**try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) From 70d0e75edc9fc81198d791b2298fe0bddb1474b2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:07:40 +0800 Subject: [PATCH 0624/1089] more --- python/sglang/srt/utils.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1671aef5bcf..b383dd7340c 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -54,10 +54,10 @@ import torch.distributed as dist import triton import zmq +from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version -from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -914,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 + total_mem = mem.total / 1024 ** 3 + available_mem = mem.available / 1024 ** 3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) + buf_size = int(0.5 * 1024 ** 3) else: buf_size = -1 @@ -1437,10 +1437,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1448,10 +1448,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - 
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1641,7 +1641,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2 :] + port_str = addr[end + 2:] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1779,7 +1779,7 @@ def retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2**try_index), max_delay) * ( + delay = min(initial_delay * (2 ** try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) @@ -1882,12 +1882,15 @@ def is_disposed(self): return self._value is None @staticmethod - def maybe_unwrap(value: Union[torch.Tensor, "DisposibleTensor"]) -> torch.Tensor: + def maybe_unwrap(value: "MaybeDisposibleTensor") -> torch.Tensor: if isinstance(value, DisposibleTensor): return value.value return value +MaybeDisposibleTensor = Union[torch.Tensor, DisposibleTensor] + + class TensorCreator: def __init__(self, creator: Callable[[], torch.Tensor]): self._creator = creator @@ -1899,7 +1902,10 @@ def create(self): return value @staticmethod - def maybe_create(value: Union[torch.Tensor, "TensorCreator"]) -> torch.Tensor: + def maybe_create(value: "MaybeTensorCreator") -> torch.Tensor: if isinstance(value, TensorCreator): return value.create() return value + + +MaybeTensorCreator = Union[torch.Tensor, TensorCreator] From ade9523f95f7f9f46dc2a1a786fb45c13808830e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:08:59 +0800 Subject: [PATCH 0625/1089] more --- python/sglang/srt/utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index b383dd7340c..22257001a61 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1864,16 +1864,24 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) +@dataclasses.dataclass +class _TensorMetadata: + shape: Any + device: Any + dtype: Any + + class DisposibleTensor: def __init__(self, value: torch.Tensor): self._value = value + self._backuped_metadata: Optional[_TensorMetadata] = None @property def value(self): assert not self.is_disposed return self._value - def dispose(self): + def dispose(self, backup_metadata: bool): assert not self.is_disposed self._value = None @@ -1887,6 +1895,18 @@ def maybe_unwrap(value: "MaybeDisposibleTensor") -> torch.Tensor: return value.value return value + @property + def shape(self): + return TODO + + @property + def device(self): + return TODO + + @property + def dtype(self): + return TODO + MaybeDisposibleTensor = Union[torch.Tensor, DisposibleTensor] From b9cd9de52229b0c7645ce309d0faf27b57144341 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:09:56 +0800 Subject: [PATCH 0626/1089] more --- python/sglang/srt/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 22257001a61..76a43873a7e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1864,17 +1864,11 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) -@dataclasses.dataclass -class _TensorMetadata: - shape: Any - device: Any - dtype: Any - class 
DisposibleTensor: def __init__(self, value: torch.Tensor): self._value = value - self._backuped_metadata: Optional[_TensorMetadata] = None + self._backedup_metadata: Optional[Dict[str, Any]] = None @property def value(self): @@ -1907,6 +1901,12 @@ def device(self): def dtype(self): return TODO + def _get_metadata(self, name: str): + if not self.is_disposed: + return getattr(self._value, name) + assert self._backedup_metadata is not None + return self._backedup_metadata[name] + MaybeDisposibleTensor = Union[torch.Tensor, DisposibleTensor] From 49f576108b394d0e83144b69555872b3ca131bc1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:10:18 +0800 Subject: [PATCH 0627/1089] more --- python/sglang/srt/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 76a43873a7e..5860ace48cd 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1864,7 +1864,6 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) - class DisposibleTensor: def __init__(self, value: torch.Tensor): self._value = value @@ -1904,7 +1903,7 @@ def dtype(self): def _get_metadata(self, name: str): if not self.is_disposed: return getattr(self._value, name) - assert self._backedup_metadata is not None + assert self._backedup_metadata is not None, "Use backup_metadata flag if you want to use metadata after dispose" return self._backedup_metadata[name] From e5b9b976c23d772c65de658314570661d393a694 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:10:31 +0800 Subject: [PATCH 0628/1089] more --- python/sglang/srt/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5860ace48cd..4fb28625806 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1890,15 +1890,15 @@ def maybe_unwrap(value: "MaybeDisposibleTensor") -> torch.Tensor: @property def shape(self): - return TODO + return self._get_metadata("shape") @property def device(self): - return TODO + return self._get_metadata("device") @property def dtype(self): - return TODO + return self._get_metadata("dtype") def _get_metadata(self, name: str): if not self.is_disposed: From 73c5a43fe0ecdee6d81f771df02227fd274379ee Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:10:48 +0800 Subject: [PATCH 0629/1089] more --- python/sglang/srt/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 4fb28625806..56ce66561f2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1874,8 +1874,10 @@ def value(self): assert not self.is_disposed return self._value - def dispose(self, backup_metadata: bool): + def dispose(self, backup_metadata: bool = True): assert not self.is_disposed + if backup_metadata: + TODO self._value = None @property From fad9b18011989fdea2170353820c634e234c398f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:10:58 +0800 Subject: [PATCH 0630/1089] more --- python/sglang/srt/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 56ce66561f2..87a42c0bc48 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1867,7 +1867,7 @@ def fast_topk(values, topk, dim): class DisposibleTensor: def __init__(self, value: torch.Tensor): self._value = value - self._backedup_metadata: 
Optional[Dict[str, Any]] = None + self._backup_metadata: Optional[Dict[str, Any]] = None @property def value(self): @@ -1877,7 +1877,7 @@ def value(self): def dispose(self, backup_metadata: bool = True): assert not self.is_disposed if backup_metadata: - TODO + self._backup_metadata = TODO self._value = None @property @@ -1905,8 +1905,8 @@ def dtype(self): def _get_metadata(self, name: str): if not self.is_disposed: return getattr(self._value, name) - assert self._backedup_metadata is not None, "Use backup_metadata flag if you want to use metadata after dispose" - return self._backedup_metadata[name] + assert self._backup_metadata is not None, "Use backup_metadata flag if you want to use metadata after dispose" + return self._backup_metadata[name] MaybeDisposibleTensor = Union[torch.Tensor, DisposibleTensor] From b914ff3d05f9c2dea2a6a23a80ffde50f2d04fc8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:11:43 +0800 Subject: [PATCH 0631/1089] more --- python/sglang/srt/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 87a42c0bc48..374886af580 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1877,7 +1877,7 @@ def value(self): def dispose(self, backup_metadata: bool = True): assert not self.is_disposed if backup_metadata: - self._backup_metadata = TODO + self._backup_metadata = self._compute_backup_metadata(self._value) self._value = None @property @@ -1908,6 +1908,12 @@ def _get_metadata(self, name: str): assert self._backup_metadata is not None, "Use backup_metadata flag if you want to use metadata after dispose" return self._backup_metadata[name] + _BACKUP_METADATA_KEYS = ["shape", "device", "dtype"] + + @staticmethod + def _compute_backup_metadata(value: torch.Tensor): + return {k: getattr(value, k) for k in DisposibleTensor._BACKUP_METADATA_KEYS} + MaybeDisposibleTensor = Union[torch.Tensor, DisposibleTensor] From f0fdbe18ca8741f60b2a5c2216187c79166b5efc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:12:48 +0800 Subject: [PATCH 0632/1089] more --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 3f2851105c9..7fbe4ea5755 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -7,7 +7,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import DisposibleTensor, is_cuda +from sglang.srt.utils import DisposibleTensor, is_cuda, TensorCreator, MaybeTensorCreator _is_cuda = is_cuda() if _is_cuda: @@ -637,7 +637,7 @@ def compute_m_num_tiles_indptr( def grouped_gemm_triton( a: torch.Tensor, b: torch.Tensor, - c: torch.Tensor, + c: MaybeTensorCreator, batch_size: int, weight_column_major: bool, seg_indptr: Optional[torch.Tensor] = None, @@ -681,8 +681,7 @@ def grouped_gemm_triton( m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"] ) - if not isinstance(c, torch.Tensor): - c = c() + c = TensorCreator.maybe_create(c) grid = lambda META: ( triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size, From 3b1c1aa22a83ad10c28c065ebe542c4bbdf41862 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:13:02 +0800 Subject: [PATCH 0633/1089] more --- 
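Note: together with the previous patch, grouped_gemm_triton now accepts
wrapper types on both sides of the gemm. The input a may arrive boxed in a
DisposibleTensor, so the bf16 activation can be freed right after fp8
quantization, and the output c may arrive as a TensorCreator, so the buffer is
only allocated once the gemm is known to run. A condensed, runnable sketch of
the two wrappers and the calling convention (simplified from the utils.py
patches earlier in this series, not the full classes):

    import torch
    from typing import Callable

    class DisposibleTensor:
        def __init__(self, value: torch.Tensor):
            self._value = value

        @staticmethod
        def maybe_unwrap(value) -> torch.Tensor:
            # Plain tensors pass through unchanged.
            return value._value if isinstance(value, DisposibleTensor) else value

        @staticmethod
        def maybe_dispose(value) -> None:
            if isinstance(value, DisposibleTensor):
                value._value = None  # drop the reference so the memory can be reused

    class TensorCreator:
        def __init__(self, creator: Callable[[], torch.Tensor]):
            self._creator = creator

        @staticmethod
        def maybe_create(value) -> torch.Tensor:
            # Plain tensors pass through; factories are invoked lazily here.
            return value._creator() if isinstance(value, TensorCreator) else value

    a = DisposibleTensor(torch.randn(4, 8))       # boxed activation
    a_plain = DisposibleTensor.maybe_unwrap(a)    # quantize this, then...
    DisposibleTensor.maybe_dispose(a)             # ...free the original early
    c = TensorCreator(lambda: torch.empty(4, 8))  # lazy output buffer
    out = TensorCreator.maybe_create(c)           # allocated only at this point
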
python/sglang/srt/layers/moe/ep_moe/kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 7fbe4ea5755..794d1b41ea3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -7,7 +7,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import DisposibleTensor, is_cuda, TensorCreator, MaybeTensorCreator +from sglang.srt.utils import DisposibleTensor, is_cuda, TensorCreator, MaybeTensorCreator, MaybeDisposibleTensor _is_cuda = is_cuda() if _is_cuda: @@ -635,7 +635,7 @@ def compute_m_num_tiles_indptr( def grouped_gemm_triton( - a: torch.Tensor, + a: MaybeDisposibleTensor, b: torch.Tensor, c: MaybeTensorCreator, batch_size: int, From 915e864f32812da08ea9fd01fe4921b6d08768a4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:13:52 +0800 Subject: [PATCH 0634/1089] more --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 9 +++------ python/sglang/srt/utils.py | 5 +++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 794d1b41ea3..9dba9a0a935 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -655,12 +655,9 @@ def grouped_gemm_triton( assert len(block_shape) == 2 block_n, block_k = block_shape[0], block_shape[1] if _is_cuda: - if isinstance(a, DisposibleTensor): - a_box = a - a, scale_a = sglang_per_token_group_quant_fp8(a.value, block_k) - a_box.dispose() - else: - a, scale_a = sglang_per_token_group_quant_fp8(a, block_k) + a_ref = a + a, scale_a = sglang_per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) + DisposibleTensor.maybe_dispose(a_ref) else: a, scale_a = per_token_group_quant_fp8(a, block_k) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 374886af580..60f1a417f82 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1890,6 +1890,11 @@ def maybe_unwrap(value: "MaybeDisposibleTensor") -> torch.Tensor: return value.value return value + @staticmethod + def maybe_dispose(value: "MaybeDisposibleTensor") -> torch.Tensor: + if isinstance(value, DisposibleTensor): + value.dispose() + @property def shape(self): return self._get_metadata("shape") From f066678323103492a0d62a8be64be08e80cbc980 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:14:17 +0800 Subject: [PATCH 0635/1089] more --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 9dba9a0a935..d60cd5693db 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -4,7 +4,6 @@ import torch import triton import triton.language as tl - from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import DisposibleTensor, is_cuda, TensorCreator, MaybeTensorCreator, MaybeDisposibleTensor @@ -438,13 +437,13 @@ def gelu_and_mul_triton_kernel( * ( 1 + tanh( - kAlpha - * ( - gate_output - + 0.044715 * gate_output * gate_output * gate_output 
- ) + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output ) ) + ) ) gate_output = gate_output.to(InDtype) @@ -654,12 +653,12 @@ def grouped_gemm_triton( if block_shape is not None: assert len(block_shape) == 2 block_n, block_k = block_shape[0], block_shape[1] + a_ref = a if _is_cuda: - a_ref = a a, scale_a = sglang_per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) - DisposibleTensor.maybe_dispose(a_ref) else: a, scale_a = per_token_group_quant_fp8(a, block_k) + DisposibleTensor.maybe_dispose(a_ref) assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] From 7b14a6fb8373b4b98f3e416fdd37c119f3ad5843 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:14:29 +0800 Subject: [PATCH 0636/1089] more --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index d60cd5693db..c4693c0b371 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -657,7 +657,7 @@ def grouped_gemm_triton( if _is_cuda: a, scale_a = sglang_per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) else: - a, scale_a = per_token_group_quant_fp8(a, block_k) + a, scale_a = per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) DisposibleTensor.maybe_dispose(a_ref) assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] From 23e6a324676c67bc9b43bf41a4d1a3a5bc314ee0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:14:55 +0800 Subject: [PATCH 0637/1089] fmt --- .../sglang/srt/layers/moe/ep_moe/kernels.py | 27 +++++++++++----- python/sglang/srt/utils.py | 32 ++++++++++--------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index c4693c0b371..d59eeeaa45d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -4,9 +4,16 @@ import torch import triton import triton.language as tl + from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import DisposibleTensor, is_cuda, TensorCreator, MaybeTensorCreator, MaybeDisposibleTensor +from sglang.srt.utils import ( + DisposibleTensor, + MaybeDisposibleTensor, + MaybeTensorCreator, + TensorCreator, + is_cuda, +) _is_cuda = is_cuda() if _is_cuda: @@ -437,13 +444,13 @@ def gelu_and_mul_triton_kernel( * ( 1 + tanh( - kAlpha - * ( - gate_output - + 0.044715 * gate_output * gate_output * gate_output + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output + ) ) ) - ) ) gate_output = gate_output.to(InDtype) @@ -655,9 +662,13 @@ def grouped_gemm_triton( block_n, block_k = block_shape[0], block_shape[1] a_ref = a if _is_cuda: - a, scale_a = sglang_per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) + a, scale_a = sglang_per_token_group_quant_fp8( + DisposibleTensor.maybe_unwrap(a), block_k + ) else: - a, scale_a = per_token_group_quant_fp8(DisposibleTensor.maybe_unwrap(a), block_k) + a, scale_a = per_token_group_quant_fp8( + DisposibleTensor.maybe_unwrap(a), block_k + ) DisposibleTensor.maybe_dispose(a_ref) assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] diff --git 
a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 60f1a417f82..b96bf1afecd 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -54,10 +54,10 @@ import torch.distributed as dist import triton import zmq -from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -914,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024 ** 3 - available_mem = mem.available / 1024 ** 3 + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024 ** 3) + buf_size = int(0.5 * 1024**3) else: buf_size = -1 @@ -1437,10 +1437,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1448,10 +1448,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1641,7 +1641,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2:] + port_str = addr[end + 2 :] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1779,7 +1779,7 @@ def retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2 ** try_index), max_delay) * ( + delay = min(initial_delay * (2**try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) @@ -1910,7 +1910,9 @@ def dtype(self): def _get_metadata(self, name: str): if not self.is_disposed: return getattr(self._value, name) - assert self._backup_metadata is not None, "Use backup_metadata flag if you want to use metadata after dispose" + assert ( + self._backup_metadata is not None + ), "Use backup_metadata flag if you want to use metadata after dispose" return self._backup_metadata[name] _BACKUP_METADATA_KEYS = ["shape", "device", "dtype"] From b0348b217126645ffc00e08a6ac8b5d70664cd36 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:15:45 +0800 Subject: [PATCH 0638/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 24d2fbce861..25e5d370d81 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -41,6 +41,7 @@ from sglang.srt.utils import ( DeepEPMode, DisposibleTensor, + TensorCreator, is_cuda, is_hip, set_weight_attrs, @@ -903,11 +904,13 @@ def forward_normal( ) # GroupGemm-0 - gateup_output_creator = lambda: torch.empty( - hidden_states_shape[0], - self.w13_weight.shape[1], - device=hidden_states_device, - 
dtype=hidden_states_dtype, + gateup_output_creator = TensorCreator( + lambda: torch.empty( + hidden_states_shape[0], + self.w13_weight.shape[1], + device=hidden_states_device, + dtype=hidden_states_dtype, + ) ) if hidden_states.value.shape[0] > 0: @@ -930,7 +933,7 @@ def forward_normal( block_shape=self.block_shape, ) else: - gateup_output = gateup_output_creator() + gateup_output = gateup_output_creator.create() # NOTE disposed earlier # hidden_states.dispose() From 5c3d1b566f61e16a7f5f40945cec1856c169c978 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:17:30 +0800 Subject: [PATCH 0639/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 29 +++++++++----------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 25e5d370d81..5a9132abbe6 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -41,6 +41,7 @@ from sglang.srt.utils import ( DeepEPMode, DisposibleTensor, + MaybeDisposibleTensor, TensorCreator, is_cuda, is_hip, @@ -874,7 +875,7 @@ def forward( def forward_normal( self, - hidden_states: DisposibleTensor, + hidden_states: MaybeDisposibleTensor, reorder_topk_ids: torch.Tensor, seg_indptr: torch.Tensor, ): @@ -882,16 +883,12 @@ def forward_normal( assert self.activation == "silu" if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.value.device, use_flashinfer=False # TODO: use flashinfer + hidden_states.device, use_flashinfer=False # TODO: use flashinfer ) - hidden_states_device = hidden_states.value.device - hidden_states_shape = hidden_states.value.shape - hidden_states_dtype = hidden_states.value.dtype - if self.activation_scheme == "dynamic" and not self.use_block_quant: max_value = ( - torch.max(hidden_states.value) + torch.max(DisposibleTensor.maybe_unwrap(hidden_states)) .repeat(self.num_experts_per_partition) .to(torch.float32) ) @@ -899,21 +896,21 @@ def forward_normal( weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, - device=hidden_states_device, + device=hidden_states.device, dtype=torch.int64, ) # GroupGemm-0 gateup_output_creator = TensorCreator( lambda: torch.empty( - hidden_states_shape[0], + hidden_states.shape[0], self.w13_weight.shape[1], - device=hidden_states_device, - dtype=hidden_states_dtype, + device=hidden_states.device, + dtype=hidden_states.dtype, ) ) - if hidden_states.value.shape[0] > 0: + if hidden_states.shape[0] > 0: gateup_output = self.grouped_gemm_runner( # NOTE pass in box a=hidden_states, @@ -946,14 +943,14 @@ def forward_normal( dtype=( self.fp8_dtype if (self.use_fp8_w8a8 and not self.use_block_quant) - else hidden_states_dtype + else hidden_states.dtype ), ) if self.w2_input_scale is None and not self.use_block_quant: self.w2_input_scale = torch.ones( self.num_experts_per_partition, dtype=torch.float32, - device=hidden_states_device, + device=hidden_states.device, ) if self.activation == "silu": @@ -976,8 +973,8 @@ def forward_normal( down_output = torch.empty( down_input.shape[0], self.w2_weight.shape[1], - device=hidden_states_device, - dtype=hidden_states_dtype, + device=hidden_states.device, + dtype=hidden_states.dtype, ) if down_input.shape[0] > 0: down_output = self.grouped_gemm_runner( From fb97f9507070a94997d3718caca91a5d9af98e07 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:19:23 +0800 Subject: [PATCH 0640/1089] more --- 
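Note: the patches just above make forward_normal read .shape, .device, and
.dtype directly off the possibly-boxed hidden states, which works because
DisposibleTensor snapshots that metadata on dispose(). A condensed, runnable
illustration of the contract (the full class in utils.py additionally asserts
sys.getrefcount(value) == 2 outside torch.compile, to catch hidden aliases
that would keep the storage alive):

    import torch

    class DisposibleTensor:
        _BACKUP_METADATA_KEYS = ["shape", "device", "dtype"]

        def __init__(self, value: torch.Tensor):
            self._value = value
            self._backup_metadata = None

        def dispose(self, backup_metadata: bool = True):
            if backup_metadata:
                # Snapshot cheap metadata before dropping the storage.
                self._backup_metadata = {
                    k: getattr(self._value, k) for k in self._BACKUP_METADATA_KEYS
                }
            self._value = None

        @property
        def shape(self):
            if self._value is not None:
                return self._value.shape
            return self._backup_metadata["shape"]

    t = DisposibleTensor(torch.empty(4, 8))
    t.dispose()
    print(t.shape)  # torch.Size([4, 8]), served from the backed-up metadata
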
python/sglang/srt/layers/moe/ep_moe/layer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 5a9132abbe6..9e73441a821 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -932,9 +932,6 @@ def forward_normal( else: gateup_output = gateup_output_creator.create() - # NOTE disposed earlier - # hidden_states.dispose() - # Act down_input = torch.empty( gateup_output.shape[0], From d184576e555cb8987ac0dbfee03be383c175fce0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:19:54 +0800 Subject: [PATCH 0641/1089] rm --- python/sglang/srt/utils.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index b96bf1afecd..24afb24e7e6 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1821,40 +1821,6 @@ def resolve(self, forward_mode): return DeepEPMode.normal -T = TypeVar("T") - - -class Withable(Generic[T]): - def __init__(self): - self._value: Optional[T] = None - - @property - def value(self) -> T: - return self._value - - @contextmanager - def with_value(self, new_value: T): - assert self._value is None - self._value = new_value - try: - yield - finally: - assert self._value is new_value - self._value = None - - -@contextmanager -def configure_deep_gemm_num_sms(num_sms): - import deep_gemm - - original_num_sms = deep_gemm.get_num_sms() - deep_gemm.set_num_sms(num_sms) - try: - yield - finally: - deep_gemm.set_num_sms(original_num_sms) - - def fast_topk(values, topk, dim): if topk == 1: # Use max along the specified dimension to get both value and index From 54b0f961efc3552a9b888d1fce5da028cdee9b0a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:24:40 +0800 Subject: [PATCH 0642/1089] more --- python/sglang/srt/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 24afb24e7e6..3fe1a2ed6e6 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1842,8 +1842,14 @@ def value(self): def dispose(self, backup_metadata: bool = True): assert not self.is_disposed + + if not torch.compiler.is_compiling(): + refcount = sys.getrefcount(self._value) + assert refcount == 2, f"{refcount=}" + if backup_metadata: self._backup_metadata = self._compute_backup_metadata(self._value) + self._value = None @property From 3efc364e51804629baf73d7f2d58f6185be00443 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:27:20 +0800 Subject: [PATCH 0643/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 23 ++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 9e73441a821..f9ec862e431 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -210,7 +210,9 @@ def __init__( self.grouped_gemm_runner = None - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward( + self, hidden_states: MaybeDisposibleTensor, router_logits: torch.Tensor + ): assert self.quant_method is not None if self.grouped_gemm_runner is None: @@ -246,7 +248,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): ) if self.activation_scheme == "dynamic" and not self.use_block_quant: max_value = ( - 
torch.max(hidden_states) + torch.max(DisposibleTensor.maybe_unwrap(hidden_states)) .repeat(self.num_experts_per_partition) .to(torch.float32) ) @@ -274,16 +276,18 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): dtype=torch.int64, ) # GroupGemm-0 - gateup_output = torch.empty( - gateup_input.shape[0], - self.w13_weight.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, + gateup_output_creator = TensorCreator( + lambda: torch.empty( + gateup_input.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) ) gateup_output = self.grouped_gemm_runner( a=gateup_input, b=self.w13_weight, - c=gateup_output, + c=gateup_output_creator, batch_size=self.num_experts_per_partition, weight_column_major=True, seg_indptr=seg_indptr_cur_rank, @@ -341,6 +345,8 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): else: raise ValueError(f"Unsupported activation: {self.activation=}") + del gateup_output + # GroupGemm-1 down_output = torch.empty( down_input.shape[0], @@ -912,7 +918,6 @@ def forward_normal( if hidden_states.shape[0] > 0: gateup_output = self.grouped_gemm_runner( - # NOTE pass in box a=hidden_states, b=self.w13_weight, c=gateup_output_creator, From 740cd8a0dc5b7c770f3eaef276f3d33389508bda Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 08:38:39 +0800 Subject: [PATCH 0644/1089] temp --- python/sglang/bench_serving.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index f6e03a3088e..b267c7c3b6d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -427,11 +427,11 @@ async def async_request_gserver( raise NotImplementedError() -async def async_request_profile(api_url: str) -> RequestFuncOutput: +async def async_request_profile(api_url: str, json_data=None) -> RequestFuncOutput: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: output = RequestFuncOutput() try: - async with session.post(url=api_url) as response: + async with session.post(url=api_url, json=json_data) as response: if response.status == 200: output.success = True else: @@ -688,7 +688,6 @@ def sample_random_requests( tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: - input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, @@ -1043,7 +1042,7 @@ async def limited_request_func(request_func_input, pbar): if profile: print("Starting profiler...") profile_output = await async_request_profile( - api_url=base_url + "/start_profile" + api_url=base_url + "/start_profile", json_data={"activities": ["MEM"]} ) if profile_output.success: print("Profiler started") From 0e2c71cb79c00cfda9815ac24d2f286171232256 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 09:03:00 +0800 Subject: [PATCH 0645/1089] more --- python/sglang/srt/models/deepseek_v2.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8ddc24d2b37..a0c2b1b0010 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -337,17 +337,15 @@ def forward_deepep( topk_weights, forward_mode=forward_mode, ) - final_hidden_states = ( - self.experts( - hidden_states=hidden_states, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - 
forward_mode=forward_mode, - ) - * self.routed_scaling_factor + final_hidden_states = self.experts( + hidden_states=hidden_states, + reorder_topk_ids=reorder_topk_ids, + seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, + forward_mode=forward_mode, ) + final_hidden_states *= self.routed_scaling_factor if self.ep_size > 1: final_hidden_states = self.deepep_dispatcher.combine( final_hidden_states, From 1d771ce82d9905799860ea08da3e2ee44535ace3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 10:45:01 +0800 Subject: [PATCH 0646/1089] Revert "temp" This reverts commit 740cd8a0dc5b7c770f3eaef276f3d33389508bda. --- python/sglang/bench_serving.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index b267c7c3b6d..f6e03a3088e 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -427,11 +427,11 @@ async def async_request_gserver( raise NotImplementedError() -async def async_request_profile(api_url: str, json_data=None) -> RequestFuncOutput: +async def async_request_profile(api_url: str) -> RequestFuncOutput: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: output = RequestFuncOutput() try: - async with session.post(url=api_url, json=json_data) as response: + async with session.post(url=api_url) as response: if response.status == 200: output.success = True else: @@ -688,6 +688,7 @@ def sample_random_requests( tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: + input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, @@ -1042,7 +1043,7 @@ async def limited_request_func(request_func_input, pbar): if profile: print("Starting profiler...") profile_output = await async_request_profile( - api_url=base_url + "/start_profile", json_data={"activities": ["MEM"]} + api_url=base_url + "/start_profile" ) if profile_output.success: print("Profiler started") From 12a3ba5977bac63997b879f2168230e431367f9d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:20:09 +0800 Subject: [PATCH 0647/1089] more --- python/sglang/bench_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index f6e03a3088e..ac8f868d39b 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -606,7 +606,7 @@ def sample_sharegpt_requests( apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") + print("Warn: output_len too small") # Download sharegpt if necessary if not os.path.isfile(dataset_path) and dataset_path == "": From 207757559e7fa5653c5e359a46485a1cb296178c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:20:27 +0800 Subject: [PATCH 0648/1089] more --- python/sglang/bench_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index ac8f868d39b..d2623238a6d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -665,7 +665,7 @@ def sample_sharegpt_requests( len(completion_token_ids) if fixed_output_len is None else fixed_output_len ) - if prompt_len < 2 or output_len < 2: + if prompt_len < 2 or ((fixed_output_len is None) and (output_len < 2)): # Prune too short sequences. 
continue From 0cdebf33bd72a628b3f69de02287eb60a7e8f22a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:20:49 +0800 Subject: [PATCH 0649/1089] more --- python/sglang/bench_serving.py | 39 +++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index d2623238a6d..098f3355f86 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -74,7 +74,7 @@ class RequestFuncOutput: def remove_prefix(text: str, prefix: str) -> str: - return text[len(prefix) :] if text.startswith(prefix) else text + return text[len(prefix):] if text.startswith(prefix) else text def remove_suffix(text: str, suffix: str) -> str: @@ -395,8 +395,8 @@ async def async_request_sglang_generate( if num_new_tokens == 0: continue adjust_itl = ( - timestamp - most_recent_timestamp - ) / num_new_tokens + timestamp - most_recent_timestamp + ) / num_new_tokens output.itl.extend([adjust_itl] * num_new_tokens) most_recent_timestamp = timestamp @@ -688,7 +688,6 @@ def sample_random_requests( tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: - input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, @@ -935,9 +934,9 @@ def calculate_metrics( output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, total_throughput=(total_input + sum(output_lens)) / dur_s, total_throughput_retokenized=(total_input + sum(retokenized_output_lens)) - / dur_s, + / dur_s, mean_ttft_ms=np.mean(ttfts or 0) - * 1000, # ttfts is empty if streaming is not supported by backend + * 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, @@ -1464,6 +1463,12 @@ def __call__(self, parser, namespace, values, option_string=None): default=1000, help="Number of prompts to process. Default is 1000.", ) + parser.add_argument( + "--skip-num-prompts", + type=int, + default=0, + help="Number of prompts to skip. Default is 0.", + ) parser.add_argument( "--sharegpt-output-len", type=int, @@ -1493,27 +1498,27 @@ def __call__(self, parser, namespace, values, option_string=None): type=float, default=0.0, help="Range of sampled ratio of input/output length, " - "used only for random dataset.", + "used only for random dataset.", ) parser.add_argument( "--request-rate", type=float, default=float("inf"), help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", ) parser.add_argument( "--max-concurrency", type=int, default=None, help="Maximum number of concurrent requests. This can be used " - "to help simulate an environment where a higher level component " - "is enforcing a maximum number of concurrent requests. While the " - "--request-rate argument controls the rate at which requests are " - "initiated, this argument will control how many are actually allowed " - "to execute at a time. This means that when used in combination, the " - "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.", + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. 
While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", ) parser.add_argument("--output-file", type=str, help="Output JSONL file name.") parser.add_argument( @@ -1542,7 +1547,7 @@ def __call__(self, parser, namespace, values, option_string=None): metavar='{"key1": "value1", "key2": "value2"}', type=str, help="Append given JSON object to the request payload. You can use this to specify" - "additional generate params like sampling params.", + "additional generate params like sampling params.", ) parser.add_argument( "--apply-chat-template", @@ -1553,7 +1558,7 @@ def __call__(self, parser, namespace, values, option_string=None): "--profile", action="store_true", help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) parser.add_argument( "--lora-name", From 87b122ad792bea4db8e6667cf2f30fea3c31acfb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:21:33 +0800 Subject: [PATCH 0650/1089] more --- python/sglang/bench_serving.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 098f3355f86..aa862dc7e25 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -480,10 +480,11 @@ def get_tokenizer( def get_dataset(args, tokenizer): + num_prompts = args.num_prompts if args.dataset_name == "sharegpt": input_requests = sample_sharegpt_requests( dataset_path=args.dataset_path, - num_requests=args.num_prompts, + num_requests=num_prompts, tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, context_len=args.sharegpt_context_len, @@ -494,7 +495,7 @@ def get_dataset(args, tokenizer): input_requests = sample_random_requests( input_len=args.random_input_len, output_len=args.random_output_len, - num_prompts=args.num_prompts, + num_prompts=num_prompts, range_ratio=args.random_range_ratio, tokenizer=tokenizer, dataset_path=args.dataset_path, From cdfff33d3e74f775162f93e135788d7ed952bacd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:21:53 +0800 Subject: [PATCH 0651/1089] more --- python/sglang/bench_serving.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index aa862dc7e25..66f93e1ad30 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -480,7 +480,7 @@ def get_tokenizer( def get_dataset(args, tokenizer): - num_prompts = args.num_prompts + num_prompts = args.num_prompts + args.skip_num_prompts if args.dataset_name == "sharegpt": input_requests = sample_sharegpt_requests( dataset_path=args.dataset_path, @@ -512,6 +512,7 @@ def get_dataset(args, tokenizer): ) else: raise ValueError(f"Unknown dataset: {args.dataset_name}") + input_requests = input_requests[args.skip_num_prompts:] return input_requests From 04bd9832d5336c6e0b5023df4d02cf64ebd457fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:22:49 +0800 Subject: [PATCH 0652/1089] more --- python/sglang/bench_serving.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 
66f93e1ad30..a94bac9f210 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1562,6 +1562,11 @@ def __call__(self, parser, namespace, values, option_string=None): help="Use Torch Profiler. The endpoint must be launched with " "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) + parser.add_argument( + "--enable-expert-distribution-record", + action="store_true", + help="Enable expert distribution recorder", + ) parser.add_argument( "--lora-name", type=str, From b0e912e22a4ee6cd299a1c4e5442479cd3bdd619 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:22:57 +0800 Subject: [PATCH 0653/1089] more --- python/sglang/bench_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index a94bac9f210..f2dafe1b6b4 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1565,7 +1565,7 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument( "--enable-expert-distribution-record", action="store_true", - help="Enable expert distribution recorder", + help="Enable expert distribution recording", ) parser.add_argument( "--lora-name", From f11b031a2487feff3ffd46e3a3f75530abc3c3d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:23:58 +0800 Subject: [PATCH 0654/1089] more --- python/sglang/bench_serving.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index f2dafe1b6b4..4903af7126d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -975,6 +975,7 @@ async def benchmark( lora_names: List[str], extra_request_body: Dict[str, Any], profile: bool, + enable_expert_distribution_recording: bool = False, pd_seperated: bool = False, flush_cache: bool = False, ): @@ -1392,6 +1393,7 @@ def run_benchmark(args_: argparse.Namespace): lora_names=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, + enable_expert_distribution_recording=args.enable_expert_distribution_recording, pd_seperated=args.pd_seperated, flush_cache=args.flush_cache, ) @@ -1563,7 +1565,7 @@ def __call__(self, parser, namespace, values, option_string=None): "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) parser.add_argument( - "--enable-expert-distribution-record", + "--enable-expert-distribution-recording", action="store_true", help="Enable expert distribution recording", ) From 4bc3df7fb96336c65dbf6a12fdf2f345eb4e5ab5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:25:20 +0800 Subject: [PATCH 0655/1089] more --- python/sglang/bench_serving.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 4903af7126d..b15ea549c79 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -975,7 +975,7 @@ async def benchmark( lora_names: List[str], extra_request_body: Dict[str, Any], profile: bool, - enable_expert_distribution_recording: bool = False, + enable_expert_distribution_record: bool = False, pd_seperated: bool = False, flush_cache: bool = False, ): @@ -1041,6 +1041,12 @@ async def limited_request_func(request_func_input, pbar): time.sleep(1.0) + if enable_expert_distribution_record: + print("Starting expert distribution record...") + output = await async_request_profile( + api_url=base_url + "/start_expert_distribution_record" + ) + assert output.success # 
Start profiler if profile: print("Starting profiler...") @@ -1085,6 +1091,16 @@ async def limited_request_func(request_func_input, pbar): profile_output = await async_request_profile(api_url=base_url + "/stop_profile") if profile_output.success: print("Profiler stopped") + if enable_expert_distribution_record: + print("Stopping expert distribution record...") + output = await async_request_profile( + api_url=base_url + "/dump_expert_distribution_record" + ) + assert output.success + output = await async_request_profile( + api_url=base_url + "/stop_expert_distribution_record" + ) + assert output.success if pbar is not None: pbar.close() @@ -1393,7 +1409,7 @@ def run_benchmark(args_: argparse.Namespace): lora_names=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, - enable_expert_distribution_recording=args.enable_expert_distribution_recording, + enable_expert_distribution_record=args.enable_expert_distribution_record, pd_seperated=args.pd_seperated, flush_cache=args.flush_cache, ) @@ -1565,9 +1581,9 @@ def __call__(self, parser, namespace, values, option_string=None): "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) parser.add_argument( - "--enable-expert-distribution-recording", + "--enable-expert-distribution-record", action="store_true", - help="Enable expert distribution recording", + help="Enable expert distribution record", ) parser.add_argument( "--lora-name", From 6005bc4d50d000e9d1fa5c6e1f37ae9315e3fb30 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:26:19 +0800 Subject: [PATCH 0656/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a53b0d29cc9..9eff26aa818 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -1,11 +1,11 @@ import logging +import os from abc import ABC from contextlib import contextmanager from copy import deepcopy from typing import Any, List, Optional, Type import torch - from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -189,8 +189,8 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): torch.cuda.synchronize() num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: local_physical_expert_idx = ( @@ -281,6 +281,7 @@ def postprocess_dumps( def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) + self._save_dir = os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_SAVE_DIR") self._records = [] def get_single_pass_gatherer_keys(self): From 93105dda37190c58d4311aff45506be95042c96a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:26:29 +0800 Subject: [PATCH 0657/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9eff26aa818..305f5ecfbeb 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -300,6 
+300,7 @@ def append( gatherer_key: str, single_pass_physical_count: torch.Tensor, ): + single_pass_physical_count = single_pass_physical_count.to("cpu") self._records.append( dict( forward_pass_id=forward_pass_id, From 89c3225ac5eb7ac03d3cc8cfe4019ee90994d237 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:26:44 +0800 Subject: [PATCH 0658/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 305f5ecfbeb..067b9f47b59 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -301,12 +301,15 @@ def append( single_pass_physical_count: torch.Tensor, ): single_pass_physical_count = single_pass_physical_count.to("cpu") + if self._save_dir is None: + single_pass_physical_count = single_pass_physical_count.tolist() + self._records.append( dict( forward_pass_id=forward_pass_id, rank=self._rank, gatherer_key=gatherer_key, - physical_count=single_pass_physical_count.tolist(), + physical_count=single_pass_physical_count, ) ) From 30f011ba552ed590a7ee316b0ca57d904c93abd1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:27:33 +0800 Subject: [PATCH 0659/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 067b9f47b59..584b05bc857 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -3,6 +3,7 @@ from abc import ABC from contextlib import contextmanager from copy import deepcopy +from pathlib import Path from typing import Any, List, Optional, Type import torch @@ -317,7 +318,12 @@ def reset(self): self._records.clear() def dump(self): - return deepcopy(self._records) + if self._save_dir is None: + return deepcopy(self._records) + else: + path_output = Path(self._save_dir) / TODO + torch.save(self._records, str(path_output)) + return [dict(path_output=str(path_output))] class _StatAccumulator(_Accumulator): From 68894bfa40fed6d9d0b6a75ab6e48bbd1e26e58e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 13:28:00 +0800 Subject: [PATCH 0660/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 584b05bc857..d5335a44a26 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -282,9 +282,12 @@ def postprocess_dumps( def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) - self._save_dir = os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_SAVE_DIR") self._records = [] + self._save_dir = os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_SAVE_DIR") + if self._save_dir is not None and not Path(self._save_dir).exists(): + Path(self._save_dir).mkdir(parents=True, exist_ok=True) + def get_single_pass_gatherer_keys(self): if False: # TODO `server_args.enable_two_batch_overlap` return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"] From 2a4a0d93faeb1146fc49d2be383b7abb961d2e31 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 
2025 13:28:23 +0800
Subject: [PATCH 0661/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index d5335a44a26..9c539dce639 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 from abc import ABC
 from contextlib import contextmanager
 from copy import deepcopy
@@ -324,7 +325,7 @@ def dump(self):
         if self._save_dir is None:
             return deepcopy(self._records)
         else:
-            path_output = Path(self._save_dir) / TODO
+            path_output = Path(self._save_dir) / f"{time.time()}-{self._rank}.pt"
             torch.save(self._records, str(path_output))
             return [dict(path_output=str(path_output))]

From ea8c68ed75d37dd37780fda128bd48aeddd0d056 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Sat, 12 Apr 2025 13:45:09 +0800
Subject: [PATCH 0662/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 9c539dce639..deb158a5115 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -326,6 +326,7 @@ def dump(self):
             return deepcopy(self._records)
         else:
             path_output = Path(self._save_dir) / f"{time.time()}-{self._rank}.pt"
+            logger.info(f"Write expert distribution to {path_output}")
             torch.save(self._records, str(path_output))
             return [dict(path_output=str(path_output))]

From d7f7f954f4113c87902ed3a54d6d5d1720b20bbf Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Sat, 12 Apr 2025 16:20:14 +0800
Subject: [PATCH 0663/1089] cherry pick

---
 python/sglang/srt/managers/eplb_simulator.py | 218 +++++++++++++++++++
 1 file changed, 218 insertions(+)
 create mode 100644 python/sglang/srt/managers/eplb_simulator.py

diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py
new file mode 100644
index 00000000000..c9d02b488e3
--- /dev/null
+++ b/python/sglang/srt/managers/eplb_simulator.py
@@ -0,0 +1,218 @@
+# TODO where to put this file?
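+#
+# Overview of the flow implemented below: expert distribution records dumped
+# via SGLANG_EXPERT_DISTRIBUTION_RECORDER_SAVE_DIR are replayed to estimate
+# per-GPU load balance with and without EPLB rebalancing:
+#   read_logical_count_of_seq -> simulate_batching
+#   -> simulate_logical_to_physical (when enable_expert_location_by_eplb)
+#   -> compute_gpu_physical_count_of_batch -> compute_utilization_rate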
+from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +import einops +import torch +from sglang.srt.managers import deepseek_eplb +from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation +from tqdm.auto import tqdm + + +@dataclass +class MyServerArgs: + chunked_prefill_size: int + ep_num_redundant_experts: int + nnodes: int + tp_size: int + enable_expert_location_by_eplb: bool + + +@dataclass +class MyExpertLocationMetadata: + physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) + logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) + + @staticmethod + def init_by_eplb(server_args: MyServerArgs, logical_count: torch.Tensor, num_physical_experts: int): + model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION + + physical_to_logical_map, logical_to_all_physical_map, _ = ( + deepseek_eplb.rebalance_experts( + weight=logical_count, + num_replicas=num_physical_experts, + num_groups=model_config_for_expert_location.num_groups, + num_nodes=server_args.nnodes, + num_gpus=server_args.tp_size, + ) + ) + + return MyExpertLocationMetadata( + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, + ) + + +# https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json +_MY_MODEL_CONFIG_FOR_EXPERT_LOCATION = ModelConfigForExpertLocation( + num_layers=61, + num_logical_experts=256, + num_groups=8, +) + + +def read_logical_count_of_seq(dir_data: Path): + physical_count_of_forward_pass_id_and_rank = defaultdict(lambda: defaultdict()) + for path in tqdm(list(dir_data.glob("*.pt"))): + for record in torch.load(path, weights_only=True): + assert physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]].get(record["rank"]) is None + physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][record["rank"]] = record[ + "physical_count"] + print(len(physical_count_of_forward_pass_id_and_rank)) + + items = [] + for forward_pass_id, physical_count_of_rank in sorted(physical_count_of_forward_pass_id_and_rank.items()): + physical_count_of_rank_tensor = torch.cat([ + physical_count + for rank, physical_count + in sorted(physical_count_of_rank.items()) + ], dim=-1) + items.append(physical_count_of_rank_tensor) + + logical_count_of_seq = torch.stack(items) + print(f"{logical_count_of_seq.shape=}") + + return logical_count_of_seq + + +def scan_combinations( + logical_count_of_seq: torch.Tensor, +): + server_args_list = [ + *[ + MyServerArgs( + chunked_prefill_size=8192 * 32, + ep_num_redundant_experts=32, + nnodes=4, + tp_size=32, + enable_expert_location_by_eplb=enable_expert_location_by_eplb, + ) + for enable_expert_location_by_eplb in [False, True] + ] + ] + + for server_args in server_args_list: + mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) + print(f"{server_args=} {mean_utilization_rate=:.2f}") + + +def simulate_execution( + logical_count_of_seq: torch.Tensor, + server_args: MyServerArgs, +): + model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION + num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + + logical_count_of_batch = simulate_batching( + logical_count_of_seq=logical_count_of_seq, + chunked_prefill_size=server_args.chunked_prefill_size, + ) + + if server_args.enable_expert_location_by_eplb: + expert_location_metadata = 
MyExpertLocationMetadata.init_by_eplb( + server_args, + logical_count=einops.einsum(logical_count_of_seq, "num_seq num_layer num_expert -> num_layer num_expert"), + num_physical_experts=num_physical_expert, + ) + physical_count_of_batch = simulate_logical_to_physical( + logical_count_of_whatever=logical_count_of_batch, + logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map, + num_physical_expert=num_physical_expert, + ) + else: + physical_count_of_batch = logical_count_of_batch + + gpu_physical_count_of_batch = compute_gpu_physical_count_of_batch( + physical_count_of_batch=physical_count_of_batch, + num_gpu=server_args.tp_size, + ) + + utilization_rate = compute_utilization_rate( + gpu_physical_count_of_batch=gpu_physical_count_of_batch, + ) + + # NOTE: first 3 layers are MLP, thus those parts are not meaningful + mean_utilization_rate = torch.mean(utilization_rate).item() + + return mean_utilization_rate + + +def simulate_batching( + logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) + chunked_prefill_size: int, +) -> torch.Tensor: + """output: (num_batch, num_layer, num_logical_expert)""" + tensor_chunks = chunker( + logical_count_of_seq, + state_reducer=lambda count, tensor: count + compute_num_token(tensor).item(), + should_chunk=lambda count: count > chunked_prefill_size, + ) + return torch.stack([torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks]) + + +def simulate_logical_to_physical( + logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) + logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) + num_physical_expert: int, +): + """output: (*, num_layer, num_physical_expert)""" + num_whatever, num_layer, num_logical_expert = logical_count_of_whatever.shape + + physical_count_of_whatever = torch.zeros( + (num_whatever, num_layer, num_physical_expert), + dtype=logical_to_all_physical_map.dtype, + ) + + for layer_id in range(num_layer): + for logical_expert_id in range(num_logical_expert): + all_physical_expert_ids = ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id + ) + for physical_expert_id in all_physical_expert_ids: + physical_count_of_whatever[:, layer_id, physical_expert_id] += \ + logical_count_of_whatever[:, layer_id, logical_expert_id] / len(all_physical_expert_ids) + + return physical_count_of_whatever + + +def compute_gpu_physical_count_of_batch( + physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_physical_expert) + num_gpu: int, +): + """output: gpu_physical_count_of_batch (num_batch, num_layer, num_gpu)""" + return einops.reduce( + physical_count_of_batch, + "num_batch num_layer (num_gpu num_expert_per_gpu) -> num_batch num_layer num_gpu", + "sum", + num_gpu=num_gpu, + ) + + +def compute_utilization_rate( + gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) +): + """output: utilization_rate (num_batch, num_layer)""" + max_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, + "num_batch num_layer num_gpu -> num_batch num_layer", 'max') + avg_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, + "num_batch num_layer num_gpu -> num_batch num_layer", 'mean') + return avg_gpu_physical_count / max_gpu_physical_count + + +def compute_num_token(whatever_with_num_layer_and_num_expert: torch.Tensor): + return whatever_with_num_layer_and_num_expert[..., -1, :].sum(dim=-1) + + +def chunker(objects, state_reducer, should_chunk): + 
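+ # Greedy chunking: fold each object into `state` via `state_reducer`
+ # (for simulate_batching, a running token count) and yield the accumulated
+ # chunk once `should_chunk(state)` fires; a trailing partial chunk is
+ # yielded at the end as well.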
state = 0 + outputs = [] + for obj in objects: + outputs.append(obj) + state = state_reducer(state, obj) + if should_chunk(state): + yield outputs + outputs = [] + state = 0 + if len(outputs) > 0: + yield outputs From 6aaf82f74560e17b604cb181676800ef4d5ee13b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:22:39 +0800 Subject: [PATCH 0664/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index c9d02b488e3..1ff9fe27270 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -77,7 +77,7 @@ def read_logical_count_of_seq(dir_data: Path): def scan_combinations( - logical_count_of_seq: torch.Tensor, + logical_count_of_seq: torch.Tensor, ): server_args_list = [ *[ @@ -98,8 +98,8 @@ def scan_combinations( def simulate_execution( - logical_count_of_seq: torch.Tensor, - server_args: MyServerArgs, + logical_count_of_seq: torch.Tensor, + server_args: MyServerArgs, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts @@ -139,8 +139,8 @@ def simulate_execution( def simulate_batching( - logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) - chunked_prefill_size: int, + logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) + chunked_prefill_size: int, ) -> torch.Tensor: """output: (num_batch, num_layer, num_logical_expert)""" tensor_chunks = chunker( @@ -152,9 +152,9 @@ def simulate_batching( def simulate_logical_to_physical( - logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) - logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) - num_physical_expert: int, + logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) + logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) + num_physical_expert: int, ): """output: (*, num_layer, num_physical_expert)""" num_whatever, num_layer, num_logical_expert = logical_count_of_whatever.shape @@ -177,8 +177,8 @@ def simulate_logical_to_physical( def compute_gpu_physical_count_of_batch( - physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_physical_expert) - num_gpu: int, + physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_physical_expert) + num_gpu: int, ): """output: gpu_physical_count_of_batch (num_batch, num_layer, num_gpu)""" return einops.reduce( @@ -190,9 +190,10 @@ def compute_gpu_physical_count_of_batch( def compute_utilization_rate( - gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) + gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) ): """output: utilization_rate (num_batch, num_layer)""" + gpu_physical_count_of_batch = gpu_physical_count_of_batch.float() max_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, "num_batch num_layer num_gpu -> num_batch num_layer", 'max') avg_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, From 0243942fe0aa349b6627e48ed3b43d17cbf3d78f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:44:12 +0800 Subject: [PATCH 0665/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 1ff9fe27270..9686740acfa 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -161,7 +161,7 @@ def simulate_logical_to_physical( physical_count_of_whatever = torch.zeros( (num_whatever, num_layer, num_physical_expert), - dtype=logical_to_all_physical_map.dtype, + dtype=torch.float32, ) for layer_id in range(num_layer): From 6fbbcc9281b3374771c6912edcec30cca551e3c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:47:21 +0800 Subject: [PATCH 0666/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 9686740acfa..6a469a6810a 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -59,7 +59,7 @@ def read_logical_count_of_seq(dir_data: Path): assert physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]].get(record["rank"]) is None physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][record["rank"]] = record[ "physical_count"] - print(len(physical_count_of_forward_pass_id_and_rank)) + # print(len(physical_count_of_forward_pass_id_and_rank)) items = [] for forward_pass_id, physical_count_of_rank in sorted(physical_count_of_forward_pass_id_and_rank.items()): @@ -71,7 +71,7 @@ def read_logical_count_of_seq(dir_data: Path): items.append(physical_count_of_rank_tensor) logical_count_of_seq = torch.stack(items) - print(f"{logical_count_of_seq.shape=}") + # print(f"{logical_count_of_seq.shape=}") return logical_count_of_seq @@ -94,7 +94,7 @@ def scan_combinations( for server_args in server_args_list: mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) - print(f"{server_args=} {mean_utilization_rate=:.2f}") + # print(f"{server_args=} {mean_utilization_rate=:.2f}") def simulate_execution( @@ -108,6 +108,7 @@ def simulate_execution( logical_count_of_seq=logical_count_of_seq, chunked_prefill_size=server_args.chunked_prefill_size, ) + # print(f"hi {logical_count_of_batch=}") if server_args.enable_expert_location_by_eplb: expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( @@ -115,11 +116,13 @@ def simulate_execution( logical_count=einops.einsum(logical_count_of_seq, "num_seq num_layer num_expert -> num_layer num_expert"), num_physical_experts=num_physical_expert, ) + # print(f"hi {expert_location_metadata=}") physical_count_of_batch = simulate_logical_to_physical( logical_count_of_whatever=logical_count_of_batch, logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map, num_physical_expert=num_physical_expert, ) + # print(f"hi {physical_count_of_batch=}") else: physical_count_of_batch = logical_count_of_batch @@ -127,10 +130,12 @@ def simulate_execution( physical_count_of_batch=physical_count_of_batch, num_gpu=server_args.tp_size, ) + # print(f"hi {gpu_physical_count_of_batch=}") utilization_rate = compute_utilization_rate( gpu_physical_count_of_batch=gpu_physical_count_of_batch, ) + # print(f"hi {utilization_rate=}") # NOTE: first 3 layers are MLP, thus those parts are not meaningful mean_utilization_rate = torch.mean(utilization_rate).item() @@ -198,7 +203,7 @@ def compute_utilization_rate( "num_batch num_layer num_gpu -> num_batch num_layer", 'max') avg_gpu_physical_count 
= einops.reduce(gpu_physical_count_of_batch, "num_batch num_layer num_gpu -> num_batch num_layer", 'mean') - return avg_gpu_physical_count / max_gpu_physical_count + return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5) def compute_num_token(whatever_with_num_layer_and_num_expert: torch.Tensor): From c445bcb049c5c8ce29e8f2071feb0b60b9b9dcf1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:47:49 +0800 Subject: [PATCH 0667/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 6a469a6810a..d21ac51e312 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -71,7 +71,7 @@ def read_logical_count_of_seq(dir_data: Path): items.append(physical_count_of_rank_tensor) logical_count_of_seq = torch.stack(items) - # print(f"{logical_count_of_seq.shape=}") + print(f"{logical_count_of_seq.shape=}") return logical_count_of_seq @@ -94,7 +94,7 @@ def scan_combinations( for server_args in server_args_list: mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) - # print(f"{server_args=} {mean_utilization_rate=:.2f}") + print(f"{server_args=} {mean_utilization_rate=:.2f}") def simulate_execution( From 47234f1e151024384b9a87a1e9fdf3b223f193c6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:48:21 +0800 Subject: [PATCH 0668/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index d21ac51e312..f297694e43e 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -102,7 +102,6 @@ def simulate_execution( server_args: MyServerArgs, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION - num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts logical_count_of_batch = simulate_batching( logical_count_of_seq=logical_count_of_seq, @@ -111,6 +110,7 @@ def simulate_execution( # print(f"hi {logical_count_of_batch=}") if server_args.enable_expert_location_by_eplb: + num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, logical_count=einops.einsum(logical_count_of_seq, "num_seq num_layer num_expert -> num_layer num_expert"), From 1acb0cc880611043a8e5110cbb00f0581730972b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:49:00 +0800 Subject: [PATCH 0669/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index f297694e43e..88cbc21a11a 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -82,12 +82,13 @@ def scan_combinations( server_args_list = [ *[ MyServerArgs( - chunked_prefill_size=8192 * 32, + chunked_prefill_size=chunked_prefill_size_per_gpu * 32, ep_num_redundant_experts=32, nnodes=4, tp_size=32, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) + for chunked_prefill_size_per_gpu in [8192, 4096, 2048, 1024] 
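+ # chunked_prefill_size is the global size: per-GPU size * 32 (tp_size GPUs)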
for enable_expert_location_by_eplb in [False, True] ] ] From f56d9f6a2dc2497166a9fc6679e0abe17c65f68c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 16:49:41 +0800 Subject: [PATCH 0670/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 88cbc21a11a..68a6bf0f424 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -88,12 +88,13 @@ def scan_combinations( tp_size=32, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) - for chunked_prefill_size_per_gpu in [8192, 4096, 2048, 1024] + for chunked_prefill_size_per_gpu in [1024, 2048, 4096, 8192, 16384] for enable_expert_location_by_eplb in [False, True] ] ] for server_args in server_args_list: + print() mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) print(f"{server_args=} {mean_utilization_rate=:.2f}") @@ -108,7 +109,7 @@ def simulate_execution( logical_count_of_seq=logical_count_of_seq, chunked_prefill_size=server_args.chunked_prefill_size, ) - # print(f"hi {logical_count_of_batch=}") + print(f"{logical_count_of_batch.shape=}") if server_args.enable_expert_location_by_eplb: num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts From aff5aa0985005630199c58881254fb6df5ad9d34 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 17:09:46 +0800 Subject: [PATCH 0671/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 68a6bf0f424..77c0ded9c60 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -83,11 +83,13 @@ def scan_combinations( *[ MyServerArgs( chunked_prefill_size=chunked_prefill_size_per_gpu * 32, - ep_num_redundant_experts=32, - nnodes=4, - tp_size=32, + ep_num_redundant_experts=ep_num_redundant_experts, + nnodes=nnodes, + tp_size=8 * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) + for ep_num_redundant_experts in [32, 64] + for nnodes in [2, 4] for chunked_prefill_size_per_gpu in [1024, 2048, 4096, 8192, 16384] for enable_expert_location_by_eplb in [False, True] ] From 97d6529e2e6cffd2030d6f1d9fa0514386208440 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 17:10:43 +0800 Subject: [PATCH 0672/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 77c0ded9c60..247cfc500a2 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -1,4 +1,6 @@ # TODO where to put this file? 
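+# polars: only used to tabulate scan_combinations results into a DataFrame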
+import polars as pl +import dataclasses from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -88,17 +90,23 @@ def scan_combinations( tp_size=8 * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) - for ep_num_redundant_experts in [32, 64] - for nnodes in [2, 4] + for ep_num_redundant_experts in [0, 32, 64] + for nnodes in [1, 2, 4] for chunked_prefill_size_per_gpu in [1024, 2048, 4096, 8192, 16384] for enable_expert_location_by_eplb in [False, True] ] ] + rows = [] for server_args in server_args_list: print() mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) print(f"{server_args=} {mean_utilization_rate=:.2f}") + rows.append(dict(**dataclasses.asdict(server_args), mean_utilization_rate=mean_utilization_rate)) + break + + df = pl.DataFrame(rows) + return df def simulate_execution( From 620fc0ad47d0bc5393eb8ff0863fb59c114df268 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 17:11:41 +0800 Subject: [PATCH 0673/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 247cfc500a2..0aaf31a9a6a 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -1,11 +1,11 @@ # TODO where to put this file? -import polars as pl import dataclasses from collections import defaultdict from dataclasses import dataclass from pathlib import Path import einops +import polars as pl import torch from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation @@ -103,7 +103,6 @@ def scan_combinations( mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) print(f"{server_args=} {mean_utilization_rate=:.2f}") rows.append(dict(**dataclasses.asdict(server_args), mean_utilization_rate=mean_utilization_rate)) - break df = pl.DataFrame(rows) return df From e4c91b113a9a48067b2ab4811f52cefff9b77462 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 17:16:17 +0800 Subject: [PATCH 0674/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 0aaf31a9a6a..fb85829dd11 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -93,7 +93,10 @@ def scan_combinations( for ep_num_redundant_experts in [0, 32, 64] for nnodes in [1, 2, 4] for chunked_prefill_size_per_gpu in [1024, 2048, 4096, 8192, 16384] - for enable_expert_location_by_eplb in [False, True] + for enable_expert_location_by_eplb in [ + *([False] if ep_num_redundant_experts == 0 else []), + True, + ] ] ] From 828461c245549ab1f82d6e0a96f27fca1c2e126a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:24:29 +0800 Subject: [PATCH 0675/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index fb85829dd11..c7232b94222 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -91,8 +91,8 @@ def scan_combinations( 
enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) for ep_num_redundant_experts in [0, 32, 64] - for nnodes in [1, 2, 4] - for chunked_prefill_size_per_gpu in [1024, 2048, 4096, 8192, 16384] + for nnodes in [4] + for chunked_prefill_size_per_gpu in [1024, 4096, 8192, 16384] for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), True, From 615ed0dd9974a091278e9efbafa7bdc4c97aac37 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:26:05 +0800 Subject: [PATCH 0676/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index c7232b94222..117aeb171e2 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -54,7 +54,7 @@ def init_by_eplb(server_args: MyServerArgs, logical_count: torch.Tensor, num_phy ) -def read_logical_count_of_seq(dir_data: Path): +def read_physical_count_of_forward_pass(dir_data: Path): physical_count_of_forward_pass_id_and_rank = defaultdict(lambda: defaultdict()) for path in tqdm(list(dir_data.glob("*.pt"))): for record in torch.load(path, weights_only=True): From 18a8bfde4f672af54c8141d41050e00d5169ca96 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:26:22 +0800 Subject: [PATCH 0677/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 117aeb171e2..5e667360b0d 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -72,10 +72,10 @@ def read_physical_count_of_forward_pass(dir_data: Path): ], dim=-1) items.append(physical_count_of_rank_tensor) - logical_count_of_seq = torch.stack(items) - print(f"{logical_count_of_seq.shape=}") + physical_count_of_forward_pass = torch.stack(items) + print(f"{physical_count_of_forward_pass.shape=}") - return logical_count_of_seq + return physical_count_of_forward_pass def scan_combinations( From 4f86b53fe02bde34008f2540918a58480d31c333 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:30:21 +0800 Subject: [PATCH 0678/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 25 ++++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 5e667360b0d..73207f0606f 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -111,6 +111,17 @@ def scan_combinations( return df +def analyze(dir_data: Path, num_gpu: int): + physical_count_of_forward_pass = read_physical_count_of_forward_pass(dir_data) + gpu_physical_count_of_forward_pass = compute_gpu_physical_count( + physical_count_of_whatever=physical_count_of_forward_pass, + num_gpu=num_gpu, + ) + utilization_rate = compute_utilization_rate(gpu_physical_count_of_forward_pass) + print(f"{utilization_rate.shape=}") + print(dir_data, torch.mean(utilization_rate).item()) + + def simulate_execution( logical_count_of_seq: torch.Tensor, server_args: MyServerArgs, @@ -140,8 +151,8 @@ def simulate_execution( else: physical_count_of_batch = logical_count_of_batch - gpu_physical_count_of_batch = compute_gpu_physical_count_of_batch( - physical_count_of_batch=physical_count_of_batch, + 
gpu_physical_count_of_batch = compute_gpu_physical_count( + physical_count_of_whatever=physical_count_of_batch, num_gpu=server_args.tp_size, ) # print(f"hi {gpu_physical_count_of_batch=}") @@ -195,14 +206,14 @@ def simulate_logical_to_physical( return physical_count_of_whatever -def compute_gpu_physical_count_of_batch( - physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_physical_expert) +def compute_gpu_physical_count( + physical_count_of_whatever: torch.Tensor, # (whatever, num_layer, num_physical_expert) num_gpu: int, ): - """output: gpu_physical_count_of_batch (num_batch, num_layer, num_gpu)""" + """output: gpu_physical_count_of_batch (whatever, num_layer, num_gpu)""" return einops.reduce( - physical_count_of_batch, - "num_batch num_layer (num_gpu num_expert_per_gpu) -> num_batch num_layer num_gpu", + physical_count_of_whatever, + "whatever num_layer (num_gpu num_expert_per_gpu) -> num_batch num_layer num_gpu", "sum", num_gpu=num_gpu, ) From 544977f6a39ab8834fea71cc28d7aed9e0d56929 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:30:34 +0800 Subject: [PATCH 0679/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 73207f0606f..7d0c9cf1156 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -111,7 +111,7 @@ def scan_combinations( return df -def analyze(dir_data: Path, num_gpu: int): +def analyze_actual_utilization_rate(dir_data: Path, num_gpu: int): physical_count_of_forward_pass = read_physical_count_of_forward_pass(dir_data) gpu_physical_count_of_forward_pass = compute_gpu_physical_count( physical_count_of_whatever=physical_count_of_forward_pass, From d526f4cd73ce0539367ae240c37c4840bcdaae18 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 12 Apr 2025 22:31:15 +0800 Subject: [PATCH 0680/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 7d0c9cf1156..621afd44210 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -213,7 +213,7 @@ def compute_gpu_physical_count( """output: gpu_physical_count_of_batch (whatever, num_layer, num_gpu)""" return einops.reduce( physical_count_of_whatever, - "whatever num_layer (num_gpu num_expert_per_gpu) -> num_batch num_layer num_gpu", + "whatever num_layer (num_gpu num_expert_per_gpu) -> whatever num_layer num_gpu", "sum", num_gpu=num_gpu, ) From 992a155c4b933c4191cda00fc63647d70a566539 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 12:41:22 +0800 Subject: [PATCH 0681/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 51 ++++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 2a296be7a06..c3e65a08921 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1,8 +1,8 @@ import logging +from contextlib import contextmanager from typing import Callable, List, Optional, Tuple import torch - from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata try: @@ -236,8 +236,8 @@ def forward( correction_bias=self.correction_bias, 
custom_routing_function=self.custom_routing_function, expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, : - ], + self.tp_rank, self.layer_id, : + ], ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -275,7 +275,7 @@ def forward( BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -478,7 +478,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -510,11 +510,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight @@ -1037,9 +1037,10 @@ def forward_deepgemm_masked( ): assert self.quant_method is not None assert self.activation == "silu" - assert ( - hidden_states_fp8[0].size(0) % 4 == 0 - ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}" + # NOTE HACK + # assert ( + # hidden_states_fp8[0].size(0) % 4 == 0 + # ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}" # GroupGemm-0 num_groups, m, k = hidden_states_fp8[0].size() @@ -1048,9 +1049,10 @@ def forward_deepgemm_masked( gateup_output = torch.empty( (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16 ) - m_grouped_gemm_fp8_fp8_bf16_nt_masked( - hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m - ) + with _ensure_get_col_major_tma_aligned_tensor_noop(): + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m + ) # Act down_input = torch.empty( @@ -1089,8 +1091,25 @@ def forward_deepgemm_masked( down_output = torch.empty( (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16 ) - m_grouped_gemm_fp8_fp8_bf16_nt_masked( - down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m - ) + with _ensure_get_col_major_tma_aligned_tensor_noop(): + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m + ) return down_output + + +@contextmanager +def _ensure_get_col_major_tma_aligned_tensor_noop(): + from deep_gemm.jit_kernels import utils + + original_func = utils.get_col_major_tma_aligned_tensor + + def patched_get_col_major_tma_aligned_tensor(*args, **kwargs): + return TODO + + utils.get_col_major_tma_aligned_tensor = patched_get_col_major_tma_aligned_tensor + try: + yield + finally: + utils.get_col_major_tma_aligned_tensor = original_func From bf307389d17f70bbee01c9197f0e7d1da64d3c1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 12:42:25 +0800 Subject: [PATCH 0682/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) 
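Note for this patch: it fills in the `return TODO` placeholder from the previous
commit. The patched wrapper calls through to the original
`get_col_major_tma_aligned_tensor` and asserts `data_ptr()` equality, i.e. it
verifies the deep_gemm helper never copies the tensor at the masked grouped-GEMM
call sites. The same wrap-and-assert idiom in isolation, as a minimal sketch;
`assert_tensor_noop`, `module`, and `name` are illustrative names, not
identifiers from this tree:

import contextlib

import torch


@contextlib.contextmanager
def assert_tensor_noop(module, name: str):
    # Temporarily wrap `module.<name>` so every call checks that the output
    # aliases the input (same storage pointer), then restore the original.
    original = getattr(module, name)

    def wrapped(x: torch.Tensor) -> torch.Tensor:
        out = original(x)
        assert out.data_ptr() == x.data_ptr(), f"{name} copied its input"
        return out

    setattr(module, name, wrapped)
    try:
        yield
    finally:
        setattr(module, name, original)

Usage mirrors the two call sites guarded in the previous patch, e.g.
`with assert_tensor_noop(utils, "get_col_major_tma_aligned_tensor"): ...`.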
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index c3e65a08921..373724ff689 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1105,8 +1105,10 @@ def _ensure_get_col_major_tma_aligned_tensor_noop(): original_func = utils.get_col_major_tma_aligned_tensor - def patched_get_col_major_tma_aligned_tensor(*args, **kwargs): - return TODO + def patched_get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + out = original_func(x) + assert x.data_ptr() == out.data_ptr(), f"get_col_major_tma_aligned_tensor is not noop ({x.data_ptr()=}, {out.data_ptr()=})" + return out utils.get_col_major_tma_aligned_tensor = patched_get_col_major_tma_aligned_tensor try: From da5acb8f7347023e739070d2c677451751729201 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 12:52:06 +0800 Subject: [PATCH 0683/1089] fmt --- python/sglang/bench_serving.py | 34 +++---- python/sglang/srt/entrypoints/engine.py | 1 + python/sglang/srt/layers/moe/ep_moe/layer.py | 23 +++-- .../srt/layers/moe/ep_moe/token_dispatcher.py | 1 - python/sglang/srt/managers/eplb_simulator.py | 94 ++++++++++++++----- .../srt/managers/expert_distribution.py | 13 ++- python/sglang/srt/managers/expert_location.py | 53 +++++++---- python/sglang/srt/models/deepseek_v2.py | 11 ++- test/srt/test_eplb.py | 40 +++++--- 9 files changed, 179 insertions(+), 91 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index b15ea549c79..439e69b912c 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -74,7 +74,7 @@ class RequestFuncOutput: def remove_prefix(text: str, prefix: str) -> str: - return text[len(prefix):] if text.startswith(prefix) else text + return text[len(prefix) :] if text.startswith(prefix) else text def remove_suffix(text: str, suffix: str) -> str: @@ -395,8 +395,8 @@ async def async_request_sglang_generate( if num_new_tokens == 0: continue adjust_itl = ( - timestamp - most_recent_timestamp - ) / num_new_tokens + timestamp - most_recent_timestamp + ) / num_new_tokens output.itl.extend([adjust_itl] * num_new_tokens) most_recent_timestamp = timestamp @@ -512,7 +512,7 @@ def get_dataset(args, tokenizer): ) else: raise ValueError(f"Unknown dataset: {args.dataset_name}") - input_requests = input_requests[args.skip_num_prompts:] + input_requests = input_requests[args.skip_num_prompts :] return input_requests @@ -936,9 +936,9 @@ def calculate_metrics( output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, total_throughput=(total_input + sum(output_lens)) / dur_s, total_throughput_retokenized=(total_input + sum(retokenized_output_lens)) - / dur_s, + / dur_s, mean_ttft_ms=np.mean(ttfts or 0) - * 1000, # ttfts is empty if streaming is not supported by backend + * 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, @@ -1518,27 +1518,27 @@ def __call__(self, parser, namespace, values, option_string=None): type=float, default=0.0, help="Range of sampled ratio of input/output length, " - "used only for random dataset.", + "used only for random dataset.", ) parser.add_argument( "--request-rate", type=float, default=float("inf"), help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize the request arrival times. 
Default is inf.", + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", ) parser.add_argument( "--max-concurrency", type=int, default=None, help="Maximum number of concurrent requests. This can be used " - "to help simulate an environment where a higher level component " - "is enforcing a maximum number of concurrent requests. While the " - "--request-rate argument controls the rate at which requests are " - "initiated, this argument will control how many are actually allowed " - "to execute at a time. This means that when used in combination, the " - "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.", + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", ) parser.add_argument("--output-file", type=str, help="Output JSONL file name.") parser.add_argument( @@ -1567,7 +1567,7 @@ def __call__(self, parser, namespace, values, option_string=None): metavar='{"key1": "value1", "key2": "value2"}', type=str, help="Append given JSON object to the request payload. You can use this to specify" - "additional generate params like sampling params.", + "additional generate params like sampling params.", ) parser.add_argument( "--apply-chat-template", @@ -1578,7 +1578,7 @@ def __call__(self, parser, namespace, values, option_string=None): "--profile", action="store_true", help="Use Torch Profiler. 
The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) parser.add_argument( "--enable-expert-distribution-record", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a7cd13d7142..7e5a1813f37 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -33,6 +33,7 @@ import zmq import zmq.asyncio from PIL.Image import Image + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 373724ff689..c4708194d40 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Tuple import torch + from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata try: @@ -236,8 +237,8 @@ def forward( correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, : - ], + self.tp_rank, self.layer_id, : + ], ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -275,7 +276,7 @@ def forward( BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -478,7 +479,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -510,11 +511,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight @@ -1051,7 +1052,11 @@ def forward_deepgemm_masked( ) with _ensure_get_col_major_tma_aligned_tensor_noop(): m_grouped_gemm_fp8_fp8_bf16_nt_masked( - hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m + hidden_states_fp8, + self.w13_weight_fp8, + gateup_output, + masked_m, + expected_m, ) # Act @@ -1107,7 +1112,9 @@ def _ensure_get_col_major_tma_aligned_tensor_noop(): def patched_get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: out = original_func(x) - assert x.data_ptr() == out.data_ptr(), f"get_col_major_tma_aligned_tensor is not noop ({x.data_ptr()=}, {out.data_ptr()=})" + assert ( + x.data_ptr() == out.data_ptr() + ), f"get_col_major_tma_aligned_tensor is not noop ({x.data_ptr()=}, {out.data_ptr()=})" return out utils.get_col_major_tma_aligned_tensor = patched_get_col_major_tma_aligned_tensor diff --git 
a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index a2e107d622d..9a81d9d16ec 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,5 +1,4 @@ from sglang.srt.managers.expert_distribution import expert_distribution_recorder -from sglang.srt.utils import DeepEPMode from sglang.srt.utils import DeepEPMode, DisposibleTensor try: diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 621afd44210..6ccf37d164b 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -7,10 +7,14 @@ import einops import polars as pl import torch -from sglang.srt.managers import deepseek_eplb -from sglang.srt.managers.expert_location import ExpertLocationMetadata, ModelConfigForExpertLocation from tqdm.auto import tqdm +from sglang.srt.managers import deepseek_eplb +from sglang.srt.managers.expert_location import ( + ExpertLocationMetadata, + ModelConfigForExpertLocation, +) + @dataclass class MyServerArgs: @@ -27,7 +31,11 @@ class MyExpertLocationMetadata: logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) @staticmethod - def init_by_eplb(server_args: MyServerArgs, logical_count: torch.Tensor, num_physical_experts: int): + def init_by_eplb( + server_args: MyServerArgs, + logical_count: torch.Tensor, + num_physical_experts: int, + ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION physical_to_logical_map, logical_to_all_physical_map, _ = ( @@ -58,18 +66,28 @@ def read_physical_count_of_forward_pass(dir_data: Path): physical_count_of_forward_pass_id_and_rank = defaultdict(lambda: defaultdict()) for path in tqdm(list(dir_data.glob("*.pt"))): for record in torch.load(path, weights_only=True): - assert physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]].get(record["rank"]) is None - physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][record["rank"]] = record[ - "physical_count"] + assert ( + physical_count_of_forward_pass_id_and_rank[ + record["forward_pass_id"] + ].get(record["rank"]) + is None + ) + physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][ + record["rank"] + ] = record["physical_count"] # print(len(physical_count_of_forward_pass_id_and_rank)) items = [] - for forward_pass_id, physical_count_of_rank in sorted(physical_count_of_forward_pass_id_and_rank.items()): - physical_count_of_rank_tensor = torch.cat([ - physical_count - for rank, physical_count - in sorted(physical_count_of_rank.items()) - ], dim=-1) + for forward_pass_id, physical_count_of_rank in sorted( + physical_count_of_forward_pass_id_and_rank.items() + ): + physical_count_of_rank_tensor = torch.cat( + [ + physical_count + for rank, physical_count in sorted(physical_count_of_rank.items()) + ], + dim=-1, + ) items.append(physical_count_of_rank_tensor) physical_count_of_forward_pass = torch.stack(items) @@ -103,9 +121,16 @@ def scan_combinations( rows = [] for server_args in server_args_list: print() - mean_utilization_rate = simulate_execution(logical_count_of_seq=logical_count_of_seq, server_args=server_args) + mean_utilization_rate = simulate_execution( + logical_count_of_seq=logical_count_of_seq, server_args=server_args + ) print(f"{server_args=} {mean_utilization_rate=:.2f}") - rows.append(dict(**dataclasses.asdict(server_args), mean_utilization_rate=mean_utilization_rate)) + 
rows.append( + dict( + **dataclasses.asdict(server_args), + mean_utilization_rate=mean_utilization_rate, + ) + ) df = pl.DataFrame(rows) return df @@ -135,10 +160,16 @@ def simulate_execution( print(f"{logical_count_of_batch.shape=}") if server_args.enable_expert_location_by_eplb: - num_physical_expert = model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts + num_physical_expert = ( + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts + ) expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, - logical_count=einops.einsum(logical_count_of_seq, "num_seq num_layer num_expert -> num_layer num_expert"), + logical_count=einops.einsum( + logical_count_of_seq, + "num_seq num_layer num_expert -> num_layer num_expert", + ), num_physical_experts=num_physical_expert, ) # print(f"hi {expert_location_metadata=}") @@ -178,7 +209,9 @@ def simulate_batching( state_reducer=lambda count, tensor: count + compute_num_token(tensor).item(), should_chunk=lambda count: count > chunked_prefill_size, ) - return torch.stack([torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks]) + return torch.stack( + [torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks] + ) def simulate_logical_to_physical( @@ -196,12 +229,17 @@ def simulate_logical_to_physical( for layer_id in range(num_layer): for logical_expert_id in range(num_logical_expert): - all_physical_expert_ids = ExpertLocationMetadata.logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id, logical_expert_id + all_physical_expert_ids = ( + ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id + ) ) for physical_expert_id in all_physical_expert_ids: - physical_count_of_whatever[:, layer_id, physical_expert_id] += \ - logical_count_of_whatever[:, layer_id, logical_expert_id] / len(all_physical_expert_ids) + physical_count_of_whatever[ + :, layer_id, physical_expert_id + ] += logical_count_of_whatever[:, layer_id, logical_expert_id] / len( + all_physical_expert_ids + ) return physical_count_of_whatever @@ -224,10 +262,16 @@ def compute_utilization_rate( ): """output: utilization_rate (num_batch, num_layer)""" gpu_physical_count_of_batch = gpu_physical_count_of_batch.float() - max_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, - "num_batch num_layer num_gpu -> num_batch num_layer", 'max') - avg_gpu_physical_count = einops.reduce(gpu_physical_count_of_batch, - "num_batch num_layer num_gpu -> num_batch num_layer", 'mean') + max_gpu_physical_count = einops.reduce( + gpu_physical_count_of_batch, + "num_batch num_layer num_gpu -> num_batch num_layer", + "max", + ) + avg_gpu_physical_count = einops.reduce( + gpu_physical_count_of_batch, + "num_batch num_layer num_gpu -> num_batch num_layer", + "mean", + ) return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index deb158a5115..ef43747edb9 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -8,6 +8,7 @@ from typing import Any, List, Optional, Type import torch + from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -85,7 +86,9 @@ def _on_hook(self, hook_name: str, **kwargs): def 
_reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") - assert self._current_layer_idx.value is None, f"{self._current_layer_idx.value=}" + assert ( + self._current_layer_idx.value is None + ), f"{self._current_layer_idx.value=}" for gatherer in self._single_pass_gatherers.values(): gatherer.reset() self._accumulator.reset() @@ -96,7 +99,9 @@ def start_record(self): logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" ) - assert self._server_args.disable_overlap_schedule, "ExpertDistributionRecorder needs disable_overlap_schedule currently (will implement this later)" + assert ( + self._server_args.disable_overlap_schedule + ), "ExpertDistributionRecorder needs disable_overlap_schedule currently (will implement this later)" self._reset() self._recording = True @@ -191,8 +196,8 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): torch.cuda.synchronize() num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: local_physical_expert_idx = ( diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 71747fec177..892aa04d06f 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,6 +5,7 @@ from typing import List, Optional import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -33,8 +34,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) return ExpertLocationMetadata.init_by_mapping( @@ -43,17 +44,16 @@ def init_trivial(server_args: ServerArgs): ) @staticmethod - def init_by_mapping( - server_args: ServerArgs, physical_to_logical_map - ): + def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): if not isinstance(physical_to_logical_map, torch.Tensor): physical_to_logical_map = torch.tensor(physical_to_logical_map) common = ExpertLocationMetadata._init_common(server_args) model_config_for_expert_location = common["model_config_for_expert_location"] - logical_to_all_physical_map = _compute_logical_to_all_physical_map(physical_to_logical_map, - num_logical_experts=model_config_for_expert_location.num_logical_experts - ) + logical_to_all_physical_map = _compute_logical_to_all_physical_map( + physical_to_logical_map, + num_logical_experts=model_config_for_expert_location.num_logical_experts, + ) return ExpertLocationMetadata( num_layers=model_config_for_expert_location.num_layers, @@ -104,8 +104,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) # TODO consider case when DP attention is disabled and DP > 1 world_size = 
server_args.tp_size @@ -153,7 +153,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int): return global_physical_expert_index % self.num_local_physical_experts def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -161,7 +161,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -180,31 +180,44 @@ def debug_str(self): ) -def _compute_logical_to_all_physical_map(physical_to_logical_map: torch.Tensor, num_logical_experts: int): +def _compute_logical_to_all_physical_map( + physical_to_logical_map: torch.Tensor, num_logical_experts: int +): # This is rarely called, so we use for loops for maximum clarity num_layers, num_physical_experts = physical_to_logical_map.shape - logical_to_all_physical_map = [[[] for _ in range(num_logical_experts)] for _ in range(num_layers)] + logical_to_all_physical_map = [ + [[] for _ in range(num_logical_experts)] for _ in range(num_layers) + ] for layer_id in range(num_layers): for physical_expert_id in range(num_physical_experts): - logical_expert_id = physical_to_logical_map[layer_id, physical_expert_id].item() - logical_to_all_physical_map[layer_id][logical_expert_id].append(physical_expert_id) + logical_expert_id = physical_to_logical_map[ + layer_id, physical_expert_id + ].item() + logical_to_all_physical_map[layer_id][logical_expert_id].append( + physical_expert_id + ) - logical_to_all_physical_map = _pad_nested_array(logical_to_all_physical_map, pad_value=-1) + logical_to_all_physical_map = _pad_nested_array( + logical_to_all_physical_map, pad_value=-1 + ) return torch.tensor(logical_to_all_physical_map) def _pad_nested_array(arr, pad_value): max_len = max(len(inner) for outer in arr for inner in outer) - padded = [[inner + [pad_value] * (max_len - len(inner)) for inner in outer] for outer in arr] + padded = [ + [inner + [pad_value] * (max_len - len(inner)) for inner in outer] + for outer in arr + ] return padded def _compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, ): # TODO maybe improve this algorithm (e.g. 
ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index efa5e3fa041..f26792cb16a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -32,7 +32,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, parallel_state, - tensor_model_parallel_all_reduce, get_tensor_model_parallel_rank, + tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( @@ -76,7 +76,10 @@ ExpertLocationMetadata, ModelConfigForExpertLocation, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict, get_global_expert_location_metadata +from sglang.srt.managers.schedule_batch import ( + get_global_expert_location_metadata, + global_server_args_dict, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip @@ -335,7 +338,9 @@ def forward_deepep( topk_group=self.topk_group, num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, - expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[self.tp_rank, self.layer_id, :], + expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ + self.tp_rank, self.layer_id, : + ], ) # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts " # f"{self.layer_id=} {topk_weights=} {topk_idx=} ") diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 34318a796a0..cfe3ef8c320 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,8 +3,9 @@ import unittest from pathlib import Path -import sglang as sgl import torch + +import sglang as sgl from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -18,7 +19,7 @@ # DeepSeek-Coder-V2-Lite-Instruct _NUM_ROUTED_EXPERTS = 64 _NUM_HIDDEN_LAYERS = 27 -_REF_OUTPUT = [', 4+4=8,', ', four plus four is eight, eight'] +_REF_OUTPUT = [", 4+4=8,", ", four plus four is eight, eight"] class TestEPLB(CustomTestCase): @@ -117,7 +118,9 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): print( f"Action: start engine to check automatically loading from storage dir" ) - engine = sgl.Engine(**engine_kwargs, eplb_storage_dir=eplb_storage_dir_a, port=22000) + engine = sgl.Engine( + **engine_kwargs, eplb_storage_dir=eplb_storage_dir_a, port=22000 + ) self._assert_behavior(engine, "not_equal_trivial") print(f"Action: shutdown engine") engine.shutdown() @@ -141,12 +144,18 @@ def test_nontrivial_location(self): offset = 3 physical_to_logical_map = ( - (offset + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat(_NUM_HIDDEN_LAYERS, 1)) - % _NUM_ROUTED_EXPERTS + offset + + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( + _NUM_HIDDEN_LAYERS, 1 + ) + ) % _NUM_ROUTED_EXPERTS + init_expert_location = dict( + physical_to_logical_map=physical_to_logical_map.tolist() ) - init_expert_location = dict(physical_to_logical_map=physical_to_logical_map.tolist()) - engine = sgl.Engine(**engine_kwargs, 
init_expert_location=json.dumps(init_expert_location)) + engine = sgl.Engine( + **engine_kwargs, init_expert_location=json.dumps(init_expert_location) + ) self._assert_behavior(engine, physical_to_logical_map[0].tolist()) engine.shutdown() del engine @@ -171,9 +180,7 @@ def test_trivial_with_redundant_experts(self): engine.shutdown() del engine - def _assert_behavior( - self, engine: sgl.Engine, expect_physical_to_local_map - ): + def _assert_behavior(self, engine: sgl.Engine, expect_physical_to_local_map): actual_output = self._engine_generate(engine) self.assertEqual(actual_output, _REF_OUTPUT) @@ -183,12 +190,16 @@ def _assert_behavior( physical_to_logical_map_layer_0 = physical_to_logical_map[0, :].tolist() print(f"{physical_to_logical_map_layer_0=}") - trivial_expert_locations = _compute_trivial_expert_locations(engine.server_args.ep_num_redundant_experts) + trivial_expert_locations = _compute_trivial_expert_locations( + engine.server_args.ep_num_redundant_experts + ) if expect_physical_to_local_map == "equal_trivial": self.assertEqual(physical_to_logical_map_layer_0, trivial_expert_locations) elif expect_physical_to_local_map == "not_equal_trivial": - self.assertNotEqual(physical_to_logical_map_layer_0, trivial_expert_locations) + self.assertNotEqual( + physical_to_logical_map_layer_0, trivial_expert_locations + ) else: self.assertEqual( physical_to_logical_map_layer_0, expect_physical_to_local_map @@ -210,7 +221,10 @@ def _engine_flush_cache(self, engine: sgl.Engine): def _compute_trivial_expert_locations(ep_num_redundant_experts: int): - return list(x % _NUM_ROUTED_EXPERTS for x in range(_NUM_ROUTED_EXPERTS + ep_num_redundant_experts)) + return list( + x % _NUM_ROUTED_EXPERTS + for x in range(_NUM_ROUTED_EXPERTS + ep_num_redundant_experts) + ) if __name__ == "__main__": From 1554502762159c0c4ed33742fb263f331e15723a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 12:52:35 +0800 Subject: [PATCH 0684/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index c4708194d40..9d485014a40 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1038,10 +1038,6 @@ def forward_deepgemm_masked( ): assert self.quant_method is not None assert self.activation == "silu" - # NOTE HACK - # assert ( - # hidden_states_fp8[0].size(0) % 4 == 0 - # ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}" # GroupGemm-0 num_groups, m, k = hidden_states_fp8[0].size() From 871034d9305f45a503ccc6358cfaa45037908db0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:30:10 +0800 Subject: [PATCH 0685/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ef43747edb9..ee0f54ba58f 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -8,7 +8,6 @@ from typing import Any, List, Optional, Type import torch - from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -73,6 +72,9 @@ def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): 
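# Hedged usage sketch of the recorder hooks in this series (the scheduler
# and model code shown here is assumed, not quoted from the repo): the
# scheduler scopes one forward pass, and MoE code reports per-layer stats
# from inside it; _on_hook then forwards each report by method name to this
# rank's single-pass gatherer.
#
#     with expert_distribution_recorder.with_forward_pass(forward_pass_id):
#         ...  # run the model; inside select_experts():
#         expert_distribution_recorder.on_select_experts(topk_ids=topk_ids)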
num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, ) + def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor): + self._on_hook("on_deepep_dispatch_low_latency", recv_count=recv_count) + def _on_hook(self, hook_name: str, **kwargs): if not self._recording: return @@ -196,8 +198,8 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): torch.cuda.synchronize() num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + 0 + ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: local_physical_expert_idx = ( From a4948b03745d35227764f8d2f5bcdfcc46bb100d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:30:28 +0800 Subject: [PATCH 0686/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ee0f54ba58f..b0ca2d82911 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -157,6 +157,9 @@ def on_deepep_dispatch_normal( ): pass + def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): + pass + def reset(self): raise NotImplementedError From 9869fb3f784dbd6caf41ff7d03ef32241362dcf9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:31:33 +0800 Subject: [PATCH 0687/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index b0ca2d82911..5d369991853 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -227,7 +227,11 @@ def on_deepep_dispatch_normal( # e.g. use naive tensor copying # need to consider CUDA graph, e.g. add initialization and after-end class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): - pass + def __init__(self): + self._data = torch.zeros((num_layers, num_local_physical_experts), dtype=torch.int) + + def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): + self._data[layer_idx, :] = packed_recv_count # --------------------------------------- Accumulator ----------------------------------------- From 217236775cae4c77f0ecff08af8c3d8370de54d7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:31:50 +0800 Subject: [PATCH 0688/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 5d369991853..e1999e5506b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -227,8 +227,13 @@ def on_deepep_dispatch_normal( # e.g. use naive tensor copying # need to consider CUDA graph, e.g. 
add initialization and after-end class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): - def __init__(self): - self._data = torch.zeros((num_layers, num_local_physical_experts), dtype=torch.int) + def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): + super().__init__(expert_location_metadata) + self._data = torch.zeros( + (expert_location_metadata.num_layers, expert_location_metadata.num_local_physical_experts), + dtype=torch.int, + device=TODO, + ) def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): self._data[layer_idx, :] = packed_recv_count From b75ba01fca35f02e0a5fba203209c5ada72a3976 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:32:26 +0800 Subject: [PATCH 0689/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e1999e5506b..038ed72be56 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -236,7 +236,14 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): ) def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): - self._data[layer_idx, :] = packed_recv_count + # Most naive implementation, can optimize later + self._data[layer_idx, :] = recv_count + + def reset(self): + self._data[...] = 0 + + def collect(self) -> torch.Tensor: + return self._data.clone() # --------------------------------------- Accumulator ----------------------------------------- From 92a69be0208ce0b1369ae9ea4a8ee9c26e6fecd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:33:09 +0800 Subject: [PATCH 0690/1089] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 9a81d9d16ec..539da92f853 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -426,6 +426,8 @@ def dispatch_b( ): hook() if self.return_recv_hook else event.current_stream_wait() + expert_distribution_recorder.on_deepep_dispatch_low_latency(masked_m) + reorder_topk_ids = seg_indptr = None return ( From 1af69736372b9311fff1b104692f95a40b9d4d04 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:34:38 +0800 Subject: [PATCH 0691/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 038ed72be56..93fd6afc5d7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -142,8 +142,13 @@ def init_new( server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata" ) -> "_SinglePassGatherer": if server_args.enable_deepep_moe: - # TODO DeepEP low latency - return _DeepepNormalSinglePassGatherer(expert_location_metadata) + # `auto` has many restrictions now, so we lower the priority to implement low-latency capturing for auto + if server_args.deepep_mode in ["normal", "auto"]: + return _DeepepNormalSinglePassGatherer(expert_location_metadata) + elif server_args.deepep_mode == "low_latency": + return 
_DeepepLowLatencySinglePassGatherer(expert_location_metadata) + else: + raise NotImplementedError return _SelectExpertsSinglePassGatherer(expert_location_metadata) def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): From 6edd623105992a3f1d231e04e002e7890e52b62d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 14:35:35 +0800 Subject: [PATCH 0692/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 93fd6afc5d7..73c36ce3430 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -237,7 +237,7 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): self._data = torch.zeros( (expert_location_metadata.num_layers, expert_location_metadata.num_local_physical_experts), dtype=torch.int, - device=TODO, + device="cuda", ) def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): From aee361e2d2c02709c4cbd93e2faa6a2f44079664 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:11:04 +0800 Subject: [PATCH 0693/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 6ccf37d164b..1de6e77f17e 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -7,13 +7,12 @@ import einops import polars as pl import torch -from tqdm.auto import tqdm - from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_location import ( ExpertLocationMetadata, ModelConfigForExpertLocation, ) +from tqdm.auto import tqdm @dataclass @@ -109,8 +108,14 @@ def scan_combinations( enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) for ep_num_redundant_experts in [0, 32, 64] - for nnodes in [4] - for chunked_prefill_size_per_gpu in [1024, 4096, 8192, 16384] + + for nnodes in [4, 8, 9] + # TODO rename this for decode + for chunked_prefill_size_per_gpu in [64, 128] + + # for nnodes in [4] + # for chunked_prefill_size_per_gpu in [1024, 4096, 8192, 16384] + for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), True, @@ -236,7 +241,7 @@ def simulate_logical_to_physical( ) for physical_expert_id in all_physical_expert_ids: physical_count_of_whatever[ - :, layer_id, physical_expert_id + :, layer_id, physical_expert_id ] += logical_count_of_whatever[:, layer_id, logical_expert_id] / len( all_physical_expert_ids ) From 8b38dab3a3dc5c04a91adcf11863015d20a28860 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:15:32 +0800 Subject: [PATCH 0694/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 73c36ce3430..1c00a4676a8 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -228,9 +228,6 @@ def on_deepep_dispatch_normal( self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) -# TODO Wait for LowLatency DeepEP merging -# e.g. use naive tensor copying -# need to consider CUDA graph, e.g. 
add initialization and after-end class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): super().__init__(expert_location_metadata) @@ -245,9 +242,11 @@ def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tenso self._data[layer_idx, :] = recv_count def reset(self): + print(f"hi _DeepepLowLatencySinglePassGatherer.reset") self._data[...] = 0 def collect(self) -> torch.Tensor: + print(f"hi _DeepepLowLatencySinglePassGatherer.collect {self._data.sum()=}") return self._data.clone() From 836c56aa179db22cb0e0ffae1f48eeece0da7aa9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:17:16 +0800 Subject: [PATCH 0695/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1c00a4676a8..e6167b815ac 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -23,6 +23,8 @@ class _ExpertDistributionRecorder: def __init__(self): self._recording = False + # TODO improve API + self._enable_in_cuda_graph = get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_ENABLE_IN_CUDA_GRAPH") self._current_layer_idx = Withable() self._current_debug_name = Withable() From 477a4dba41af1623e7106df697f6c89887021282 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:17:53 +0800 Subject: [PATCH 0696/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e6167b815ac..68f44698bcd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -123,6 +123,10 @@ def dump_record(self): self._reset() return output + @property + def _enable(self): + return self._recording or (TODO and self._enable_in_cuda_graph) + expert_distribution_recorder = _ExpertDistributionRecorder() From 3e1d92d929194873cf104d8aef2561e07676896e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:18:28 +0800 Subject: [PATCH 0697/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 68f44698bcd..2ad7b6f0bc2 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -125,7 +125,7 @@ def dump_record(self): @property def _enable(self): - return self._recording or (TODO and self._enable_in_cuda_graph) + return self._recording or (self._enable_in_cuda_graph and torch.cuda.is_current_stream_capturing()) expert_distribution_recorder = _ExpertDistributionRecorder() From 83657d45d2963d1926fa938f572086f18b55b44d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:18:48 +0800 Subject: [PATCH 0698/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2ad7b6f0bc2..1cb16040d04 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -78,7 +78,7 @@ def 
on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor): self._on_hook("on_deepep_dispatch_low_latency", recv_count=recv_count) def _on_hook(self, hook_name: str, **kwargs): - if not self._recording: + if not self._enable: return gatherer = self._single_pass_gatherers[ self._accumulator.get_single_pass_gatherer_key( From 2f0c75779bdc2f79454332d502d52c41542c1bac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:19:08 +0800 Subject: [PATCH 0699/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 1cb16040d04..470306f1020 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -78,7 +78,7 @@ def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor): self._on_hook("on_deepep_dispatch_low_latency", recv_count=recv_count) def _on_hook(self, hook_name: str, **kwargs): - if not self._enable: + if not (self._recording or (self._enable_in_cuda_graph and torch.cuda.is_current_stream_capturing())): return gatherer = self._single_pass_gatherers[ self._accumulator.get_single_pass_gatherer_key( @@ -123,10 +123,6 @@ def dump_record(self): self._reset() return output - @property - def _enable(self): - return self._recording or (self._enable_in_cuda_graph and torch.cuda.is_current_stream_capturing()) - expert_distribution_recorder = _ExpertDistributionRecorder() From 70ef01d8a8e0f71e5534fc304a3cffb9cf4d5be4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 15:23:53 +0800 Subject: [PATCH 0700/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 470306f1020..941999c654a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -50,11 +50,18 @@ def with_debug_name(self, debug_name): @contextmanager def with_forward_pass(self, forward_pass_id: int): + self._on_forward_pass_start() try: yield finally: self._on_forward_pass_end(forward_pass_id) + def _on_forward_pass_start(self): + if not self._recording: + return + for gatherer_key, gatherer in self._single_pass_gatherers.items(): + gatherer.reset() + def _on_forward_pass_end(self, forward_pass_id: int): if not self._recording: return @@ -63,7 +70,6 @@ def _on_forward_pass_end(self, forward_pass_id: int): self._accumulator.append( forward_pass_id, gatherer_key, single_pass_physical_count ) - gatherer.reset() def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) @@ -244,11 +250,9 @@ def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tenso self._data[layer_idx, :] = recv_count def reset(self): - print(f"hi _DeepepLowLatencySinglePassGatherer.reset") self._data[...] 
= 0 def collect(self) -> torch.Tensor: - print(f"hi _DeepepLowLatencySinglePassGatherer.collect {self._data.sum()=}") return self._data.clone() From 56d4a91b33914e71a389b0eeb1cb1f6a3cb73bc5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 17:58:36 +0800 Subject: [PATCH 0701/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 47 +++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 1de6e77f17e..d4a354b8e60 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -31,9 +31,9 @@ class MyExpertLocationMetadata: @staticmethod def init_by_eplb( - server_args: MyServerArgs, - logical_count: torch.Tensor, - num_physical_experts: int, + server_args: MyServerArgs, + logical_count: torch.Tensor, + num_physical_experts: int, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION @@ -66,10 +66,10 @@ def read_physical_count_of_forward_pass(dir_data: Path): for path in tqdm(list(dir_data.glob("*.pt"))): for record in torch.load(path, weights_only=True): assert ( - physical_count_of_forward_pass_id_and_rank[ - record["forward_pass_id"] - ].get(record["rank"]) - is None + physical_count_of_forward_pass_id_and_rank[ + record["forward_pass_id"] + ].get(record["rank"]) + is None ) physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][ record["rank"] @@ -78,7 +78,7 @@ def read_physical_count_of_forward_pass(dir_data: Path): items = [] for forward_pass_id, physical_count_of_rank in sorted( - physical_count_of_forward_pass_id_and_rank.items() + physical_count_of_forward_pass_id_and_rank.items() ): physical_count_of_rank_tensor = torch.cat( [ @@ -96,7 +96,7 @@ def read_physical_count_of_forward_pass(dir_data: Path): def scan_combinations( - logical_count_of_seq: torch.Tensor, + logical_count_of_seq: torch.Tensor, ): server_args_list = [ *[ @@ -109,7 +109,10 @@ def scan_combinations( ) for ep_num_redundant_experts in [0, 32, 64] - for nnodes in [4, 8, 9] + for nnodes in [ + 4, 8, + *([9] if ep_num_redundant_experts == 32 else []), + ] # TODO rename this for decode for chunked_prefill_size_per_gpu in [64, 128] @@ -153,8 +156,8 @@ def analyze_actual_utilization_rate(dir_data: Path, num_gpu: int): def simulate_execution( - logical_count_of_seq: torch.Tensor, - server_args: MyServerArgs, + logical_count_of_seq: torch.Tensor, + server_args: MyServerArgs, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION @@ -166,8 +169,8 @@ def simulate_execution( if server_args.enable_expert_location_by_eplb: num_physical_expert = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, @@ -205,8 +208,8 @@ def simulate_execution( def simulate_batching( - logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) - chunked_prefill_size: int, + logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) + chunked_prefill_size: int, ) -> torch.Tensor: """output: (num_batch, num_layer, num_logical_expert)""" tensor_chunks = chunker( @@ -220,9 +223,9 @@ def simulate_batching( def simulate_logical_to_physical( - logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) - logical_to_all_physical_map: 
torch.Tensor, # (num_layer, num_logical_experts, X) - num_physical_expert: int, + logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) + logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) + num_physical_expert: int, ): """output: (*, num_layer, num_physical_expert)""" num_whatever, num_layer, num_logical_expert = logical_count_of_whatever.shape @@ -250,8 +253,8 @@ def simulate_logical_to_physical( def compute_gpu_physical_count( - physical_count_of_whatever: torch.Tensor, # (whatever, num_layer, num_physical_expert) - num_gpu: int, + physical_count_of_whatever: torch.Tensor, # (whatever, num_layer, num_physical_expert) + num_gpu: int, ): """output: gpu_physical_count_of_batch (whatever, num_layer, num_gpu)""" return einops.reduce( @@ -263,7 +266,7 @@ def compute_gpu_physical_count( def compute_utilization_rate( - gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) + gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) ): """output: utilization_rate (num_batch, num_layer)""" gpu_physical_count_of_batch = gpu_physical_count_of_batch.float() From 7b8eba9def575b0b6bc18f0402d6f6b0fb23b2d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 17:59:36 +0800 Subject: [PATCH 0702/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index d4a354b8e60..e31b2655620 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -17,7 +17,8 @@ @dataclass class MyServerArgs: - chunked_prefill_size: int + # When prefill, this is equivalent to `chunked_prefill_size` + process_num_tokens_overall: int ep_num_redundant_experts: int nnodes: int tp_size: int @@ -101,7 +102,7 @@ def scan_combinations( server_args_list = [ *[ MyServerArgs( - chunked_prefill_size=chunked_prefill_size_per_gpu * 32, + process_num_tokens_overall=chunked_prefill_size_per_gpu * 32, ep_num_redundant_experts=ep_num_redundant_experts, nnodes=nnodes, tp_size=8 * nnodes, @@ -113,7 +114,6 @@ def scan_combinations( 4, 8, *([9] if ep_num_redundant_experts == 32 else []), ] - # TODO rename this for decode for chunked_prefill_size_per_gpu in [64, 128] # for nnodes in [4] @@ -163,7 +163,7 @@ def simulate_execution( logical_count_of_batch = simulate_batching( logical_count_of_seq=logical_count_of_seq, - chunked_prefill_size=server_args.chunked_prefill_size, + process_num_tokens_overall=server_args.process_num_tokens_overall, ) print(f"{logical_count_of_batch.shape=}") @@ -209,13 +209,13 @@ def simulate_execution( def simulate_batching( logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) - chunked_prefill_size: int, + process_num_tokens_overall: int, ) -> torch.Tensor: """output: (num_batch, num_layer, num_logical_expert)""" tensor_chunks = chunker( logical_count_of_seq, state_reducer=lambda count, tensor: count + compute_num_token(tensor).item(), - should_chunk=lambda count: count > chunked_prefill_size, + should_chunk=lambda count: count > process_num_tokens_overall, ) return torch.stack( [torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks] From 760875a0fcc40e42003cad33f4f311d4bd355de8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:00:00 +0800 Subject: [PATCH 0703/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 14 +++++++------- 1 file 
changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index e31b2655620..283efc8e399 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -18,7 +18,7 @@ @dataclass class MyServerArgs: # When prefill, this is equivalent to `chunked_prefill_size` - process_num_tokens_overall: int + num_tokens_in_batch_overall: int ep_num_redundant_experts: int nnodes: int tp_size: int @@ -102,7 +102,7 @@ def scan_combinations( server_args_list = [ *[ MyServerArgs( - process_num_tokens_overall=chunked_prefill_size_per_gpu * 32, + num_tokens_in_batch_overall=num_tokens_in_batch_per_gpu * 32, ep_num_redundant_experts=ep_num_redundant_experts, nnodes=nnodes, tp_size=8 * nnodes, @@ -114,10 +114,10 @@ def scan_combinations( 4, 8, *([9] if ep_num_redundant_experts == 32 else []), ] - for chunked_prefill_size_per_gpu in [64, 128] + for num_tokens_in_batch_per_gpu in [64, 128] # for nnodes in [4] - # for chunked_prefill_size_per_gpu in [1024, 4096, 8192, 16384] + # for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), @@ -163,7 +163,7 @@ def simulate_execution( logical_count_of_batch = simulate_batching( logical_count_of_seq=logical_count_of_seq, - process_num_tokens_overall=server_args.process_num_tokens_overall, + num_tokens_in_batch_overall=server_args.num_tokens_in_batch_overall, ) print(f"{logical_count_of_batch.shape=}") @@ -209,13 +209,13 @@ def simulate_execution( def simulate_batching( logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) - process_num_tokens_overall: int, + num_tokens_in_batch_overall: int, ) -> torch.Tensor: """output: (num_batch, num_layer, num_logical_expert)""" tensor_chunks = chunker( logical_count_of_seq, state_reducer=lambda count, tensor: count + compute_num_token(tensor).item(), - should_chunk=lambda count: count > process_num_tokens_overall, + should_chunk=lambda count: count > num_tokens_in_batch_overall, ) return torch.stack( [torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks] From ce8b34113585837faf26079c5acd821ba2d224ae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:00:24 +0800 Subject: [PATCH 0704/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 283efc8e399..9f17e7b9921 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -99,13 +99,14 @@ def read_physical_count_of_forward_pass(dir_data: Path): def scan_combinations( logical_count_of_seq: torch.Tensor, ): + num_gpu_per_node = 8 server_args_list = [ *[ MyServerArgs( - num_tokens_in_batch_overall=num_tokens_in_batch_per_gpu * 32, + num_tokens_in_batch_overall=num_tokens_in_batch_per_gpu * num_gpu_per_node * nnodes, ep_num_redundant_experts=ep_num_redundant_experts, nnodes=nnodes, - tp_size=8 * nnodes, + tp_size=num_gpu_per_node * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) for ep_num_redundant_experts in [0, 32, 64] From e8047c9c2c9f72236fe5bbe1e3f022f386bdf94d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:01:19 +0800 Subject: [PATCH 0705/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 16 +++++++--------- 1 file changed, 7 
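A hedged sketch of what simulate_batching's chunker call does (the chunker
helper itself is defined further down in this file; this standalone scalar
version uses plain token counts instead of tensors, and note a later patch
in this series changes the flush trigger from > to >=):

    def greedy_chunks(token_counts, limit):
        batches, cur, acc = [], [], 0
        for n in token_counts:
            cur.append(n)
            acc += n
            if acc >= limit:  # flush once the batch reaches the limit
                batches.append(cur)
                cur, acc = [], 0
        if cur:
            batches.append(cur)
        return batches

    greedy_chunks([3, 5, 2, 4], limit=8)  # -> [[3, 5], [2, 4]]

Each flushed group's per-layer, per-expert counts are then summed into one
simulated batch.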
insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 9f17e7b9921..0f9e126cc29 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -130,16 +130,11 @@ def scan_combinations( rows = [] for server_args in server_args_list: print() - mean_utilization_rate = simulate_execution( + info = simulate_execution( logical_count_of_seq=logical_count_of_seq, server_args=server_args ) - print(f"{server_args=} {mean_utilization_rate=:.2f}") - rows.append( - dict( - **dataclasses.asdict(server_args), - mean_utilization_rate=mean_utilization_rate, - ) - ) + print(f"{server_args=} {info=}") + rows.append(dict(**dataclasses.asdict(server_args), **info)) df = pl.DataFrame(rows) return df @@ -205,7 +200,10 @@ def simulate_execution( # NOTE: first 3 layers are MLP, thus those parts are not meaningful mean_utilization_rate = torch.mean(utilization_rate).item() - return mean_utilization_rate + return dict( + mean_utilization_rate=mean_utilization_rate, + num_simulated_batches=logical_count_of_batch.shape[0], + ) def simulate_batching( From 6da4cdbbf44fb82e3d8719a4a802a57cf30a33dc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:06:13 +0800 Subject: [PATCH 0706/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 0f9e126cc29..da9a4e7e592 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -60,6 +60,7 @@ def init_by_eplb( num_logical_experts=256, num_groups=8, ) +_MY_MODEL_CONFIG_NUM_EXPERTS_PER_TOK = 8 def read_physical_count_of_forward_pass(dir_data: Path): @@ -283,7 +284,8 @@ def compute_utilization_rate( def compute_num_token(whatever_with_num_layer_and_num_expert: torch.Tensor): - return whatever_with_num_layer_and_num_expert[..., -1, :].sum(dim=-1) + num_token_mul_num_experts = whatever_with_num_layer_and_num_expert[..., -1, :].sum(dim=-1) + return num_token_mul_num_experts / _MY_MODEL_CONFIG_NUM_EXPERTS_PER_TOK def chunker(objects, state_reducer, should_chunk): From 46c44316ccdc91f0173edf07f3f68fb25d0aaf63 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:07:32 +0800 Subject: [PATCH 0707/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index da9a4e7e592..22d69d62964 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -215,7 +215,7 @@ def simulate_batching( tensor_chunks = chunker( logical_count_of_seq, state_reducer=lambda count, tensor: count + compute_num_token(tensor).item(), - should_chunk=lambda count: count > num_tokens_in_batch_overall, + should_chunk=lambda count: count >= num_tokens_in_batch_overall, ) return torch.stack( [torch.stack(tensor_chunk).sum(dim=0) for tensor_chunk in tensor_chunks] From a44c382faa0135cd270e7ce6a7aa8d911c20ff29 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:09:51 +0800 Subject: [PATCH 0708/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 
22d69d62964..efd44acf0d3 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -110,14 +110,18 @@ def scan_combinations( tp_size=num_gpu_per_node * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, ) - for ep_num_redundant_experts in [0, 32, 64] + # decode + for ep_num_redundant_experts in [0, 32] for nnodes in [ - 4, 8, + 4, + *([8] if ep_num_redundant_experts == 0 else []), *([9] if ep_num_redundant_experts == 32 else []), ] for num_tokens_in_batch_per_gpu in [64, 128] + # prefill + # for ep_num_redundant_experts in [0, 32, 64] # for nnodes in [4] # for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] From fbc272531a7128d47edaf3295bd8bb516a6874be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:12:03 +0800 Subject: [PATCH 0709/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index efd44acf0d3..ec56880fc9e 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -114,6 +114,8 @@ def scan_combinations( # decode for ep_num_redundant_experts in [0, 32] for nnodes in [ + 1, + 2, 4, *([8] if ep_num_redundant_experts == 0 else []), *([9] if ep_num_redundant_experts == 32 else []), From 16426fbb4655bf7db5893e9f2fbaed78de04dbb7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:22:37 +0800 Subject: [PATCH 0710/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index ec56880fc9e..ea4204ed577 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -112,20 +112,20 @@ def scan_combinations( ) # decode - for ep_num_redundant_experts in [0, 32] - for nnodes in [ - 1, - 2, - 4, - *([8] if ep_num_redundant_experts == 0 else []), - *([9] if ep_num_redundant_experts == 32 else []), - ] - for num_tokens_in_batch_per_gpu in [64, 128] + # for ep_num_redundant_experts in [0, 32] + # for nnodes in [ + # 1, + # 2, + # 4, + # *([8] if ep_num_redundant_experts == 0 else []), + # *([9] if ep_num_redundant_experts == 32 else []), + # ] + # for num_tokens_in_batch_per_gpu in [64, 128] # prefill - # for ep_num_redundant_experts in [0, 32, 64] - # for nnodes in [4] - # for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] + for ep_num_redundant_experts in [0, 32, 64] + for nnodes in [1, 2, 4] + for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), From 48cc814ae72af02ee3aa3cdcb247fc58b880d5c9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:34:51 +0800 Subject: [PATCH 0711/1089] more --- python/sglang/srt/layers/moe/topk.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 0fd70f88c02..9c95cc8d3bd 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,10 +16,9 @@ import torch import torch.nn.functional as F - from sglang.srt.managers.expert_distribution import expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.utils import 
get_compiler_backend, is_cuda, is_hip +from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip, get_bool_env_var _is_cuda = is_cuda() _is_hip = is_hip() @@ -300,7 +299,10 @@ def select_experts( if expert_logical_to_rank_dispatch_physical_map is not None: # TODO this is inefficient, and I will fuse into existing kernels - topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] + if get_bool_env_var("SGLANG_HACK_EXPERT_LOCATION_DISPATCH_RANDOM"): + topk_ids = expert_logical_to_all_physical_map[topk_ids, TODO] + else: + topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) From 7b609dfa7eaae7f3e6abe6b5e6e47f60ccafd339 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:35:08 +0800 Subject: [PATCH 0712/1089] more --- python/sglang/srt/layers/moe/topk.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 9c95cc8d3bd..ceb665793eb 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,7 +300,9 @@ def select_experts( if expert_logical_to_rank_dispatch_physical_map is not None: # TODO this is inefficient, and I will fuse into existing kernels if get_bool_env_var("SGLANG_HACK_EXPERT_LOCATION_DISPATCH_RANDOM"): - topk_ids = expert_logical_to_all_physical_map[topk_ids, TODO] + chosen_dispatch_index = torch.randint(0, 65536, topk_ids.shape, + dtype=torch.int32) % expert_logical_to_all_physical_map_num_valid + topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] else: topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] From 9965833c6abb5b24e4fc49c0e66983eed7818d4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:36:12 +0800 Subject: [PATCH 0713/1089] more --- python/sglang/srt/layers/moe/topk.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index ceb665793eb..5d0ed31782d 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,12 +300,25 @@ def select_experts( if expert_logical_to_rank_dispatch_physical_map is not None: # TODO this is inefficient, and I will fuse into existing kernels if get_bool_env_var("SGLANG_HACK_EXPERT_LOCATION_DISPATCH_RANDOM"): - chosen_dispatch_index = torch.randint(0, 65536, topk_ids.shape, - dtype=torch.int32) % expert_logical_to_all_physical_map_num_valid - topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + topk_ids = _hack_expert_location_dispatch_random( + topk_ids=topk_ids, + expert_logical_to_all_physical_map=expert_logical_to_all_physical_map, + expert_logical_to_all_physical_map_num_valid=expert_logical_to_all_physical_map_num_valid, + ) else: topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def _hack_expert_location_dispatch_random( + topk_ids, + expert_logical_to_all_physical_map, + expert_logical_to_all_physical_map_num_valid, +): + chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32) + % expert_logical_to_all_physical_map_num_valid) + return expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] From 462c9b8218f67417b11629ad75ce160c0b4dc14b Mon Sep 17 
00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:36:25 +0800 Subject: [PATCH 0714/1089] more --- python/sglang/srt/layers/moe/topk.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 5d0ed31782d..0f27cac2a49 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -246,6 +246,8 @@ def select_experts( correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, + expert_logical_to_all_physical_map=None, + expert_logical_to_all_physical_map_num_valid=None, ): n_share_experts_fusion = 0 if global_server_args_dict["n_share_experts_fusion"] is not None: From 8fe809c304974fb62fe457a5dc1de2440840dcf7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:36:42 +0800 Subject: [PATCH 0715/1089] more --- python/sglang/srt/models/deepseek_v2.py | 75 +++++++++++++------------ 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f26792cb16a..f5081c2dfdc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -24,10 +24,6 @@ import torch import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig - from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -83,6 +79,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -214,8 +213,8 @@ def __init__( self.experts = MoEImpl( num_experts=config.n_routed_experts - + self.n_share_experts_fusion - + global_server_args_dict["ep_num_redundant_experts"], + + self.n_share_experts_fusion + + global_server_args_dict["ep_num_redundant_experts"], top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -280,7 +279,7 @@ def __init__( router_topk=self.top_k, permute_fusion=True, num_experts=config.n_routed_experts - + global_server_args_dict["ep_num_redundant_experts"], + + global_server_args_dict["ep_num_redundant_experts"], num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, @@ -339,8 +338,10 @@ def forward_deepep( num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, : - ], + self.tp_rank, self.layer_id, : + ], + expert_logical_to_all_physical_map=TODO, + expert_logical_to_all_physical_map_num_valid=TODO, ) # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts " # f"{self.layer_id=} {topk_weights=} {topk_idx=} ") @@ -437,7 +438,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta 
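# A self-contained sketch of the random physical-dispatch hack from the
# topk.py patches above (names simplified; this is an illustration, not the
# repo code): each (token, expert) pair draws a random index modulo the
# number of valid replicas, so every physical replica of a logical expert
# receives an equal share of the traffic.
import torch

def random_dispatch(topk_ids, logical_to_all_physical, num_valid):
    # topk_ids: (num_tokens, k) logical expert ids for one layer
    # logical_to_all_physical: (num_logical, X), right-padded with -1
    # num_valid: (num_logical,) number of real replicas per logical expert
    rand = torch.randint(0, 65536, topk_ids.shape, device=topk_ids.device)
    chosen = rand % num_valid[topk_ids]
    return logical_to_all_physical[topk_ids, chosen]

l2ap = torch.tensor([[0, 4, -1], [1, -1, -1]])  # logical 0 replicated at 0 and 4
nv = torch.tensor([2, 1])
random_dispatch(torch.tensor([[0, 1]]), l2ap, nv)  # tensor([[0, 1]]) or tensor([[4, 1]])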
self.max_position_embeddings = max_position_embeddings @@ -550,12 +551,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -567,8 +568,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -609,7 +610,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 + self.scaling = self.qk_head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -796,16 +797,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] + v = kv[..., self.qk_nope_head_dim:] + k_pe = latent_cache[:, :, self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe + k[..., self.qk_nope_head_dim:] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe + latent_cache[:, :, self.kv_lora_rank:] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -858,11 +859,11 @@ def forward_absorb( v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -938,15 +939,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank :] + k_pe = k_input[..., self.kv_lora_rank:] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank :] = q_pe - k_input[..., self.kv_lora_rank :] = k_pe + q_input[..., self.kv_lora_rank:] = q_pe + k_input[..., self.kv_lora_rank:] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) - q_input[..., 
self.kv_lora_rank :] = q_pe + q_input[..., self.kv_lora_rank:] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -1003,7 +1004,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank :] = k_pe_output + k_input[..., self.kv_lora_rank:] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1162,7 +1163,7 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): execution_mode = ( _DecoderLayerExecutionMode.MLP_INPUT_ONE if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) else _DecoderLayerExecutionMode.MLP_INPUT_ALL ) return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) @@ -1568,7 +1569,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1599,11 +1600,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) From 865b43d7ce8e9f9b85542b1b11acb2e45d0d3d42 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:37:16 +0800 Subject: [PATCH 0716/1089] more --- python/sglang/srt/models/deepseek_v2.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f5081c2dfdc..863809a47b0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -328,6 +328,7 @@ def forward_deepep( # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) + expert_location_metadata = get_global_expert_location_metadata() topk_weights, topk_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -337,11 +338,12 @@ def forward_deepep( topk_group=self.topk_group, num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, - expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, : - ], - expert_logical_to_all_physical_map=TODO, - expert_logical_to_all_physical_map_num_valid=TODO, + expert_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ + self.tp_rank, self.layer_id, :], + expert_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[self.layer_id, + :], + expert_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[ + self.layer_id, :], ) # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts " # f"{self.layer_id=} {topk_weights=} {topk_idx=} ") From 
1d39c4bab517e5c3427205554c2f7fe6eca317e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:37:49 +0800 Subject: [PATCH 0717/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 892aa04d06f..a5537390c9d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -5,7 +5,6 @@ from typing import List, Optional import torch - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -61,6 +60,7 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_all_physical_map_num_valid=torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1), logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map, num_gpus=common["world_size"], From a0c7c9ee8dfa345263b9693e48f310bf38174a6b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:38:00 +0800 Subject: [PATCH 0718/1089] more --- python/sglang/srt/managers/expert_location.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index a5537390c9d..438e5f3784b 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -18,6 +18,7 @@ class ExpertLocationMetadata: num_logical_experts: int physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) + logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) # (num_gpus, layers, num_logical_experts) logical_to_rank_dispatch_physical_map: torch.Tensor From 37881deac0fbab7f7c6cd2834e645ac9718e9a7a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:38:31 +0800 Subject: [PATCH 0719/1089] more --- python/sglang/srt/managers/expert_location.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 438e5f3784b..5766a6d5695 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -61,7 +61,8 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map): num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, - logical_to_all_physical_map_num_valid=torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1), + logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid( + logical_to_all_physical_map), logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map, num_gpus=common["world_size"], @@ -90,7 +91,8 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): num_logical_experts=model_config_for_expert_location.num_logical_experts, num_local_physical_experts=common["num_local_physical_experts"], 
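# Worked example for _compute_logical_to_all_physical_map_num_valid above
# (hypothetical tensor; -1 is the pad value from _pad_nested_array): the
# replica count per (layer, logical expert) is the number of non-pad slots.
import torch

m = torch.tensor([[[0, 4, -1], [1, -1, -1]]])  # (layers=1, logical=2, X=3)
torch.count_nonzero(m != -1, dim=-1)           # tensor([[2, 1]])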
physical_to_logical_map=physical_to_logical_map, - logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid( + logical_to_all_physical_map), logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map, num_gpus=common["world_size"], @@ -216,6 +218,10 @@ def _pad_nested_array(arr, pad_value): return padded +def _compute_logical_to_all_physical_map_num_valid(logical_to_all_physical_map): + return torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1) + + def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, From 87740afd647269da7690fc704897119d21daea78 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:39:11 +0800 Subject: [PATCH 0720/1089] more --- python/sglang/srt/managers/expert_location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 5766a6d5695..1566cc5ea62 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -91,6 +91,7 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor): num_logical_experts=model_config_for_expert_location.num_logical_experts, num_local_physical_experts=common["num_local_physical_experts"], physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid( logical_to_all_physical_map), logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( @@ -132,6 +133,7 @@ def update(self, other: "ExpertLocationMetadata"): for field in [ "physical_to_logical_map", + "logical_to_all_physical_map_num_valid", "logical_to_rank_dispatch_physical_map", ]: # Cannot update address to avoid breaking CUDA graph From a50b0a87c2abcf56424bc6a16cafb93a1777a0f2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:39:50 +0800 Subject: [PATCH 0721/1089] more --- python/sglang/srt/layers/moe/topk.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 0f27cac2a49..1b3ad129797 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,8 +300,7 @@ def select_experts( ) if expert_logical_to_rank_dispatch_physical_map is not None: - # TODO this is inefficient, and I will fuse into existing kernels - if get_bool_env_var("SGLANG_HACK_EXPERT_LOCATION_DISPATCH_RANDOM"): + if forward_mode.is_extend(): topk_ids = _hack_expert_location_dispatch_random( topk_ids=topk_ids, expert_logical_to_all_physical_map=expert_logical_to_all_physical_map, From 172e30bc4ea38a13d61aafb237842edc11dbb692 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:40:14 +0800 Subject: [PATCH 0722/1089] more --- python/sglang/srt/managers/expert_location.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 1566cc5ea62..034d4a9ebb5 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -145,9 +145,12 @@ def update(self, other: "ExpertLocationMetadata"): setattr(self, field, getattr(other, field)) def to(self, device): - 
self.logical_to_rank_dispatch_physical_map = ( - self.logical_to_rank_dispatch_physical_map.to(device) - ) + for field in [ + "logical_to_all_physical_map", + "logical_to_all_physical_map_num_valid", + "logical_to_rank_dispatch_physical_map", + ]: + setattr(self, field, getattr(self, field).to(device)) # -------------------------------- usage ------------------------------------ From 9ce9d01c5e68f3b92c3c5422f15726829e27de84 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:40:37 +0800 Subject: [PATCH 0723/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 034d4a9ebb5..61019825b72 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -142,7 +142,7 @@ def update(self, other: "ExpertLocationMetadata"): for field in [ "logical_to_all_physical_map", ]: - setattr(self, field, getattr(other, field)) + setattr(self, field, getattr(other, field).to(getattr(self, field).device)) def to(self, device): for field in [ From c493293d46152beaea40c56b23d3950b649dee35 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:41:01 +0800 Subject: [PATCH 0724/1089] more --- python/sglang/srt/managers/expert_location.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 61019825b72..be7caf23b09 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -140,6 +140,7 @@ def update(self, other: "ExpertLocationMetadata"): getattr(self, field)[...] = getattr(other, field) for field in [ + # TODO maybe make last dim size const to simplify the logic "logical_to_all_physical_map", ]: setattr(self, field, getattr(other, field).to(getattr(self, field).device)) From 19f3165a85e3f2d9c84eec3a1da343f74f644c34 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:41:47 +0800 Subject: [PATCH 0725/1089] more --- python/sglang/srt/layers/moe/topk.py | 1 + python/sglang/srt/models/deepseek_v2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 1b3ad129797..30ec8986aac 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -246,6 +246,7 @@ def select_experts( correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, + forward_mode=None, expert_logical_to_all_physical_map=None, expert_logical_to_all_physical_map_num_valid=None, ): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 863809a47b0..cd427dbfbeb 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -338,6 +338,7 @@ def forward_deepep( topk_group=self.topk_group, num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, + forward_mode=forward_mode, expert_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ self.tp_rank, self.layer_id, :], expert_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[self.layer_id, From fb8422549b88218388a68aa022ea57bb0707a33c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:42:10 +0800 Subject: 
[PATCH 0726/1089] more --- python/sglang/srt/layers/moe/topk.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 30ec8986aac..efe377ab94a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -301,6 +301,7 @@ def select_experts( ) if expert_logical_to_rank_dispatch_physical_map is not None: + # TODO optimize these things later if forward_mode.is_extend(): topk_ids = _hack_expert_location_dispatch_random( topk_ids=topk_ids, From 51d5ec952fdb62a78e8e05e4944cfd681c9f367c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:55:28 +0800 Subject: [PATCH 0727/1089] more --- python/sglang/srt/layers/moe/topk.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index efe377ab94a..567c9b7f245 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -318,10 +318,14 @@ def select_experts( @torch.compile(dynamic=True, backend=get_compiler_backend()) def _hack_expert_location_dispatch_random( - topk_ids, - expert_logical_to_all_physical_map, - expert_logical_to_all_physical_map_num_valid, + topk_ids: torch.Tensor, + expert_logical_to_all_physical_map: torch.Tensor, + expert_logical_to_all_physical_map_num_valid: torch.Tensor, ): + topk_ids_original_shape = topk_ids.shape + topk_ids = topk_ids.flatten() chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32) - % expert_logical_to_all_physical_map_num_valid) - return expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + % expert_logical_to_all_physical_map_num_valid[topk_ids]) + topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + topk_ids = topk_ids.view(topk_ids_original_shape) + return topk_ids From 20db77c0fa75114887b184b2eab28b23d7ed8220 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 18:57:04 +0800 Subject: [PATCH 0728/1089] more --- python/sglang/srt/layers/moe/topk.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 567c9b7f245..071f056169a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -323,9 +323,12 @@ def _hack_expert_location_dispatch_random( expert_logical_to_all_physical_map_num_valid: torch.Tensor, ): topk_ids_original_shape = topk_ids.shape + device = topk_ids.device topk_ids = topk_ids.flatten() - chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32) + + chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) % expert_logical_to_all_physical_map_num_valid[topk_ids]) topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + topk_ids = topk_ids.view(topk_ids_original_shape) return topk_ids From fddb394a0ef254b13e51a83a9708fdbfa21bf564 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:43:34 +0800 Subject: [PATCH 0729/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 30 +++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index ea4204ed577..98a81f6e641 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -1,8 +1,10 @@ # TODO where to 
put this file? import dataclasses +import json from collections import defaultdict from dataclasses import dataclass from pathlib import Path +from typing import Optional import einops import polars as pl @@ -23,6 +25,7 @@ class MyServerArgs: nnodes: int tp_size: int enable_expert_location_by_eplb: bool + init_expert_location: Optional[str] @dataclass @@ -109,8 +112,11 @@ def scan_combinations( nnodes=nnodes, tp_size=num_gpu_per_node * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, + init_expert_location=init_expert_location, ) + for init_expert_location in [None, "/host_home/temp_sglang_server2local/1744461420780309768.json"] + # decode # for ep_num_redundant_experts in [0, 32] # for nnodes in [ @@ -123,9 +129,12 @@ def scan_combinations( # for num_tokens_in_batch_per_gpu in [64, 128] # prefill - for ep_num_redundant_experts in [0, 32, 64] - for nnodes in [1, 2, 4] - for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] + for ep_num_redundant_experts in [0, 32] + for nnodes in [4] + for num_tokens_in_batch_per_gpu in [8192] + # for ep_num_redundant_experts in [0, 32, 64] + # for nnodes in [1, 2, 4] + # for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), @@ -153,8 +162,11 @@ def analyze_actual_utilization_rate(dir_data: Path, num_gpu: int): physical_count_of_whatever=physical_count_of_forward_pass, num_gpu=num_gpu, ) + print(f"{gpu_physical_count_of_forward_pass.shape=}") utilization_rate = compute_utilization_rate(gpu_physical_count_of_forward_pass) print(f"{utilization_rate.shape=}") + print(f"{torch.mean(utilization_rate, dim=0)=}") + print(f"{torch.mean(utilization_rate[:, 3:]).item()=}") print(dir_data, torch.mean(utilization_rate).item()) @@ -175,12 +187,16 @@ def simulate_execution( model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts ) + if (x := server_args.init_expert_location) is not None: + print(f"Compute logical_count from {x}") + logical_count = json.loads(Path(x).read_text())["logical_count"] + else: + print(f"Compute logical_count from logical_count_of_seq") + logical_count = einops.einsum(logical_count_of_seq, + "num_seq num_layer num_expert -> num_layer num_expert", ) expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, - logical_count=einops.einsum( - logical_count_of_seq, - "num_seq num_layer num_expert -> num_layer num_expert", - ), + logical_count=logical_count, num_physical_experts=num_physical_expert, ) # print(f"hi {expert_location_metadata=}") From 89032b6715f37b36d60d15ff91de6dde0ce84185 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:44:26 +0800 Subject: [PATCH 0730/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 98a81f6e641..733fdfc7b85 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -115,7 +115,7 @@ def scan_combinations( init_expert_location=init_expert_location, ) - for init_expert_location in [None, "/host_home/temp_sglang_server2local/1744461420780309768.json"] + for init_expert_location in ["/host_home/temp_sglang_server2local/1744461420780309768.json", None] # decode # for ep_num_redundant_experts in [0, 32] @@ -189,7 +189,7 @@ def simulate_execution( ) if (x := server_args.init_expert_location) is not 
None: print(f"Compute logical_count from {x}") - logical_count = json.loads(Path(x).read_text())["logical_count"] + logical_count = torch.tensor(json.loads(Path(x).read_text())["logical_count"]) else: print(f"Compute logical_count from logical_count_of_seq") logical_count = einops.einsum(logical_count_of_seq, From 3a42504661d73c61513abb41904a7df17eaa1767 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:48:46 +0800 Subject: [PATCH 0731/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 733fdfc7b85..33f65fd54f0 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -188,15 +188,15 @@ def simulate_execution( + server_args.ep_num_redundant_experts ) if (x := server_args.init_expert_location) is not None: - print(f"Compute logical_count from {x}") - logical_count = torch.tensor(json.loads(Path(x).read_text())["logical_count"]) + print(f"Compute eplb_input_logical_count from {x}") + eplb_input_logical_count = torch.tensor(json.loads(Path(x).read_text())["logical_count"]) else: - print(f"Compute logical_count from logical_count_of_seq") - logical_count = einops.einsum(logical_count_of_seq, - "num_seq num_layer num_expert -> num_layer num_expert", ) + print(f"Compute eplb_input_logical_count from logical_count_of_seq") + eplb_input_logical_count = einops.einsum(logical_count_of_seq, + "num_seq num_layer num_expert -> num_layer num_expert", ) expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, - logical_count=logical_count, + logical_count=eplb_input_logical_count, num_physical_experts=num_physical_expert, ) # print(f"hi {expert_location_metadata=}") From 72ad03dc46e02bc63a9159ee5a77953e717587d1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:49:55 +0800 Subject: [PATCH 0732/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 33f65fd54f0..189ec95d194 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -173,6 +173,7 @@ def analyze_actual_utilization_rate(dir_data: Path, num_gpu: int): def simulate_execution( logical_count_of_seq: torch.Tensor, server_args: MyServerArgs, + override_eplb_input_logical_count: Optional[torch.Tensor] = None, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION @@ -187,13 +188,18 @@ def simulate_execution( model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts ) - if (x := server_args.init_expert_location) is not None: + + if server_args.init_expert_location == "from_variable": + print(f"Compute eplb_input_logical_count from override_eplb_input_logical_count") + eplb_input_logical_count = override_eplb_input_logical_count + elif (x := server_args.init_expert_location) is not None: print(f"Compute eplb_input_logical_count from {x}") eplb_input_logical_count = torch.tensor(json.loads(Path(x).read_text())["logical_count"]) else: print(f"Compute eplb_input_logical_count from logical_count_of_seq") eplb_input_logical_count = einops.einsum(logical_count_of_seq, "num_seq num_layer num_expert -> num_layer num_expert", ) + expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( 
server_args, logical_count=eplb_input_logical_count, From 7f057b9741ce915cbf9a0aaa6120b3e28fb96dcc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:50:07 +0800 Subject: [PATCH 0733/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 189ec95d194..7641b803e8d 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -102,6 +102,7 @@ def read_physical_count_of_forward_pass(dir_data: Path): def scan_combinations( logical_count_of_seq: torch.Tensor, + override_eplb_input_logical_count: Optional[torch.Tensor] = None, ): num_gpu_per_node = 8 server_args_list = [ @@ -147,7 +148,8 @@ def scan_combinations( for server_args in server_args_list: print() info = simulate_execution( - logical_count_of_seq=logical_count_of_seq, server_args=server_args + logical_count_of_seq=logical_count_of_seq, server_args=server_args, + override_eplb_input_logical_count=override_eplb_input_logical_count, ) print(f"{server_args=} {info=}") rows.append(dict(**dataclasses.asdict(server_args), **info)) From 8c074224d4e089645d194f415d88f4e18efbfbfc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 13 Apr 2025 20:50:57 +0800 Subject: [PATCH 0734/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 7641b803e8d..1664dedc9f5 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -116,7 +116,8 @@ def scan_combinations( init_expert_location=init_expert_location, ) - for init_expert_location in ["/host_home/temp_sglang_server2local/1744461420780309768.json", None] + # for init_expert_location in ["/host_home/temp_sglang_server2local/1744461420780309768.json", None] + for init_expert_location in ["from_variable"] # decode # for ep_num_redundant_experts in [0, 32] From 3df6736e5d08f79b551cdacb35c9ce0cd52ece46 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:52:52 +0800 Subject: [PATCH 0735/1089] more --- python/sglang/srt/entrypoints/engine.py | 174 ++++++++++++------------ 1 file changed, 88 insertions(+), 86 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 7e5a1813f37..a472509866a 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -33,7 +33,6 @@ import zmq import zmq.asyncio from PIL.Image import Image - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -138,33 +137,33 @@ def __init__(self, **kwargs): ) def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be an image instance, file name, URL, or base64 encoded string. 
- # Can be formatted as: - # - Single image for a single request - # - List of images (one per request in a batch) - # - List of lists of images (multiple images per request) - # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - return_hidden_states: bool = False, - stream: bool = False, + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, + stream: bool = False, ) -> Union[Dict, Iterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. @@ -203,32 +202,32 @@ def generator_wrapper(): return ret async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be an image instance, file name, URL, or base64 encoded string. - # Can be formatted as: - # - Single image for a single request - # - List of images (one per request in a batch) - # - List of lists of images (multiple images per request) - # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - stream: bool = False, + self, + # The input prompt. It can be a single prompt or a batch of prompts. 
+ prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, ) -> Union[Dict, AsyncIterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. @@ -255,15 +254,15 @@ async def async_generate( return await generator.__anext__() def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, ) -> Dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. @@ -311,13 +310,13 @@ def get_server_info(self): } def init_weights_update_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", ): """Initialize parameter update group.""" obj = InitWeightsUpdateGroupReqInput( @@ -346,10 +345,10 @@ def update_weights_from_distributed(self, name: str, dtype, shape): ) def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, torch.Tensor]], - load_format: Optional[str] = None, - flush_cache: bool = True, + self, + named_tensors: List[Tuple[str, torch.Tensor]], + load_format: Optional[str] = None, + flush_cache: bool = True, ): """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true to avoid duplicated operations such as clearing cache.""" @@ -386,10 +385,10 @@ def update_expert_location(self, expert_location_metadata: ExpertLocationMetadat ) def update_weights_from_disk( - self, - model_path: str, - load_format: Optional[str] = None, - param_categories: Optional[List[str]] = None, + self, + model_path: str, + load_format: Optional[str] = None, + param_categories: Optional[List[str]] = None, ): """Update the weights from disk inplace without re-launching the engine. 
@@ -508,7 +507,7 @@ def sigquit_handler(signum, frame): def _launch_subprocesses( - server_args: ServerArgs, port_args: Optional[PortArgs] = None + server_args: ServerArgs, port_args: Optional[PortArgs] = None ) -> Tuple[TokenizerManager, Dict]: """ Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. @@ -528,10 +527,13 @@ def _launch_subprocesses( server_args.model_path, server_args.tokenizer_path ) - eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None - expert_location_metadata = _compute_initial_expert_location_metadata( - server_args, eplb_manager - ) + if server_args.node_rank == 0: + eplb_manager = EPLBManager(server_args) if server_args.enable_eplb else None + expert_location_metadata = _compute_initial_expert_location_metadata( + server_args, eplb_manager + ) + else: + eplb_manager = expert_location_metadata = None scheduler_procs = [] if server_args.dp_size == 1: @@ -549,8 +551,8 @@ def _launch_subprocesses( for tp_rank in tp_rank_range: reader, writer = mp.Pipe(duplex=False) gpu_id = ( - server_args.base_gpu_id - + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + server_args.base_gpu_id + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step ) proc = mp.Process( target=run_scheduler_process, @@ -648,7 +650,7 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata( - server_args: ServerArgs, eplb_manager: EPLBManager + server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: try: From acf3209a95b25c3d41fee580d7df3d5989dbf13e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:53:44 +0800 Subject: [PATCH 0736/1089] more --- python/sglang/srt/managers/data_parallel_controller.py | 9 +++++---- python/sglang/srt/managers/scheduler.py | 4 ++-- python/sglang/srt/managers/tokenizer_manager.py | 2 +- python/sglang/srt/managers/tp_worker.py | 2 +- python/sglang/srt/managers/tp_worker_overlap_thread.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 6 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 2ff04bd3872..af7aba342a4 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -18,6 +18,7 @@ import signal import threading from enum import Enum, auto +from typing import Optional import psutil import setproctitle @@ -60,7 +61,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], ) -> None: # Parse args self.max_total_num_tokens = None @@ -163,7 +164,7 @@ def launch_tensor_parallel_group_thread( self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], base_gpu_id: int, dp_rank: int, ready_event: threading.Event, @@ -193,7 +194,7 @@ def launch_tensor_parallel_group( self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], base_gpu_id: int, dp_rank: int, ): @@ -289,7 +290,7 @@ def event_loop(self): def run_data_parallel_controller_process( server_args: ServerArgs, port_args: PortArgs, - 
expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], pipe_writer, ): setproctitle.setproctitle("sglang::data_parallel_controller") diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 55de6b4a7bc..2a4ea129180 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -171,7 +171,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], gpu_id: int, tp_rank: int, dp_rank: Optional[int], @@ -2028,7 +2028,7 @@ def _import_static_state(model, static_params): def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], gpu_id: int, tp_rank: int, dp_rank: Optional[int], diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2d08fdb8ebb..96b3285dd9c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -147,7 +147,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], eplb_manager: Optional[EPLBManager], ): # Parse args diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 76c0f6bd91a..dde03972ed7 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -46,7 +46,7 @@ class TpModelWorker: def __init__( self, server_args: ServerArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], gpu_id: int, tp_rank: int, dp_rank: Optional[int], diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index ab1e3c53eee..d4217d9f4fb 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -55,7 +55,7 @@ class TpModelWorkerClient: def __init__( self, server_args: ServerArgs, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], gpu_id: int, tp_rank: int, dp_rank: Optional[int], diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5c566b333b0..68af99aa8cb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -108,7 +108,7 @@ class ModelRunner: def __init__( self, model_config: ModelConfig, - expert_location_metadata: ExpertLocationMetadata, + expert_location_metadata: Optional[ExpertLocationMetadata], mem_fraction_static: float, gpu_id: int, tp_rank: int, From 04ca5e341bc5b1e824e80842de861c946e7deba4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:54:11 +0800 Subject: [PATCH 0737/1089] more --- python/sglang/srt/model_executor/model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 68af99aa8cb..f3b712543bb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -91,7 +91,7 
@@ monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, - set_cuda_arch, + set_cuda_arch, broadcast_pyobj, ) logger = logging.getLogger(__name__) @@ -192,6 +192,7 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + [expert_location_metadata] = broadcast_pyobj(data=[expert_location_metadata], rank=torch.distributed.get_rank()) expert_location_metadata.to(server_args.device) set_global_expert_location_metadata(expert_location_metadata) if self.tp_rank == 0 and get_bool_env_var( From bb2cbe593516e452fe10d4075b201c3722b7c10d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:56:04 +0800 Subject: [PATCH 0738/1089] more --- python/sglang/srt/server_args.py | 76 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6993987d556..f817b232da8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -167,6 +167,7 @@ class ServerArgs: enable_eplb: bool = False eplb_storage_dir: str = "/tmp/eplb_storage" eplb_rebalance_period: Optional[int] = None + enable_expert_distribution_recorder: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -431,8 +432,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -456,21 +457,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." 
+ '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -483,13 +484,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -524,9 +525,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -568,8 +569,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -589,7 +590,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1011,6 +1012,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", ) + parser.add_argument( + "--enable-expert-distribution-recorder", + action="store_true", + help="Enable expert distribution recorder", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -1054,7 +1060,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." 
- "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1067,8 +1073,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1162,7 +1168,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1181,7 +1187,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps From 4cdb32dd4fb94570c657ce05d74069d196241032 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:56:41 +0800 Subject: [PATCH 0739/1089] more --- python/sglang/srt/server_args.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f817b232da8..8c0683c01ae 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -334,6 +334,10 @@ def __post_init__(self): f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) + if self.enable_eplb: + self.enable_scheduler_input_blocker = True + self.enable_expert_distribution_recorder = True + if self.ep_num_redundant_experts > 0: assert ( self.enable_deepep_moe From 0d48b7634337d6a8d37e4b3cee5150c82cdc6167 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:57:06 +0800 Subject: [PATCH 0740/1089] more --- python/sglang/srt/server_args.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8c0683c01ae..dc92d2cd7a6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -337,6 +337,9 @@ def __post_init__(self): if self.enable_eplb: self.enable_scheduler_input_blocker = True self.enable_expert_distribution_recorder = True + logger.info( + f"EPLB is enabled. The enable_scheduler_input_blocker and enable_expert_distribution_recorder are automatically enabled." 
+ ) if self.ep_num_redundant_experts > 0: assert ( From 897157cf6d11a55e97cbb97bf075fe49d1864850 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:57:36 +0800 Subject: [PATCH 0741/1089] more --- python/sglang/srt/model_executor/model_runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f3b712543bb..622e759fbb6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,6 @@ import torch import torch.distributed as dist - from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -187,7 +186,7 @@ def __init__( ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -211,10 +210,6 @@ def initialize(self, min_per_gpu_memory: float): enable=self.server_args.enable_memory_saver ) - # Load the model - self.sampler = Sampler() - self.load_model() - expert_distribution_recorder.initialize( server_args, get_global_expert_location_metadata(), @@ -222,6 +217,10 @@ def initialize(self, min_per_gpu_memory: float): rank=self.tp_rank, ) + # Load the model + self.sampler = Sampler() + self.load_model() + # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -969,7 +968,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() From e9b0df4ef55e5a4a5bb8db1798b25fe6391fbf9b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:58:20 +0800 Subject: [PATCH 0742/1089] more --- python/sglang/srt/managers/expert_distribution.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 941999c654a..f2f90bd513d 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -130,7 +130,17 @@ def dump_record(self): return output -expert_distribution_recorder = _ExpertDistributionRecorder() +_global_expert_distribution_recorder: Optional[_ExpertDistributionRecorder] = None + + +def get_global_expert_distribution_recorder(): + return _global_expert_distribution_recorder + + +def set_global_expert_distribution_recorder(value): + global _global_expert_distribution_recorder + assert _global_expert_distribution_recorder is None + _global_expert_distribution_recorder = value def postprocess_dumps( From 2a007224fd7f0ddcd3cdabb34efa5cb08e82ea3d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 08:59:04 +0800 Subject: [PATCH 0743/1089] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 6 +++--- python/sglang/srt/layers/moe/topk.py | 4 ++-- python/sglang/srt/managers/scheduler.py | 8 ++++---- python/sglang/srt/model_executor/model_runner.py | 6 +++--- python/sglang/srt/models/deepseek_v2.py | 4 ++-- python/sglang/srt/models/qwen2_moe.py | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 539da92f853..ca1cf67bdb4 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,4 @@ -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.utils import DeepEPMode, DisposibleTensor try: @@ -265,7 +265,7 @@ def _dispatch_core( allocate_on_comm_stream=(previous_event is not None) and self.async_finish, ) - expert_distribution_recorder.on_deepep_dispatch_normal( + get_global_expert_distribution_recorder().on_deepep_dispatch_normal( num_recv_tokens_per_expert_list ) @@ -426,7 +426,7 @@ def dispatch_b( ): hook() if self.return_recv_hook else event.current_stream_wait() - expert_distribution_recorder.on_deepep_dispatch_low_latency(masked_m) + get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(masked_m) reorder_topk_ids = seg_indptr = None diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 071f056169a..2061d5f36f3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,7 +16,7 @@ import torch import torch.nn.functional as F -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip, get_bool_env_var @@ 
-311,7 +311,7 @@ def select_experts( else: topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] - expert_distribution_recorder.on_select_experts(topk_ids=topk_ids) + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2a4ea129180..55563dea09c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -53,7 +53,7 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, @@ -1974,11 +1974,11 @@ def stop_profile(self) -> None: def expert_distribution_handle(self, recv_req: ExpertDistributionReq): dump_output = None if recv_req == ExpertDistributionReq.START_RECORD: - expert_distribution_recorder.start_record() + get_global_expert_distribution_recorder().start_record() elif recv_req == ExpertDistributionReq.STOP_RECORD: - expert_distribution_recorder.stop_record() + get_global_expert_distribution_recorder().stop_record() elif recv_req == ExpertDistributionReq.DUMP_RECORD: - dump_output = expert_distribution_recorder.dump_record() + dump_output = get_global_expert_distribution_recorder().dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") return ExpertDistributionReqOutput(dump_output=dump_output) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 622e759fbb6..542f3f72654 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -44,7 +44,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import expert_distribution_recorder +from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import ( @@ -210,7 +210,7 @@ def initialize(self, min_per_gpu_memory: float): enable=self.server_args.enable_memory_saver ) - expert_distribution_recorder.initialize( + get_global_expert_distribution_recorder().initialize( server_args, get_global_expert_location_metadata(), # TODO handle DP!=TP case @@ -1052,7 +1052,7 @@ def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: self.forward_pass_id += 1 - with expert_distribution_recorder.with_forward_pass(self.forward_pass_id): + with get_global_expert_distribution_recorder().with_forward_pass(self.forward_pass_id): return self._forward_raw(forward_batch, skip_attn_backend_init) def _forward_raw( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index cd427dbfbeb..e97ff9367c1 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index cd427dbfbeb..e97ff9367c1 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -67,7 +67,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import expert_distribution_recorder
+from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.managers.expert_location import (
     ExpertLocationMetadata,
     ModelConfigForExpertLocation,
@@ -1386,7 +1386,7 @@ def forward(
         residual = None
         for i in range(len(self.layers)):
-            with expert_distribution_recorder.with_current_layer(i):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual

diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 4cce4f0aca3..00045645f38 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -44,7 +44,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import expert_distribution_recorder
+from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.managers.expert_location import (
     ExpertLocationMetadata,
     ModelConfigForExpertLocation,
@@ -372,7 +372,7 @@ def forward(
             hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
-            with expert_distribution_recorder.with_current_layer(i):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual

From fb931be9252040ca9fc685bf1641b51a4d0fb597 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 08:59:38 +0800
Subject: [PATCH 0744/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 4 ++--
 python/sglang/srt/model_executor/model_runner.py  | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index f2f90bd513d..415161deb28 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -18,7 +18,7 @@
 # --------------------------------------- Entrypoint -----------------------------------------

-class _ExpertDistributionRecorder:
+class ExpertDistributionRecorder:
     """Global expert distribution recording"""

     def __init__(self):
@@ -130,7 +130,7 @@ def dump_record(self):
         return output

-_global_expert_distribution_recorder: Optional[_ExpertDistributionRecorder] = None
+_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = None

 def get_global_expert_distribution_recorder():

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 542f3f72654..f1b7971c2c0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -44,7 +44,8 @@
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder, \
+    set_global_expert_distribution_recorder, ExpertDistributionRecorder
 from sglang.srt.managers.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput
 from sglang.srt.managers.schedule_batch import (
@@ -210,12 +211,11 @@ def initialize(self, min_per_gpu_memory: float):
             enable=self.server_args.enable_memory_saver
         )

-        get_global_expert_distribution_recorder().initialize(
+        set_global_expert_distribution_recorder(ExpertDistributionRecorder(
             server_args,
             get_global_expert_location_metadata(),
-            # TODO handle DP!=TP case
             rank=self.tp_rank,
-        )
+        ))

         # Load the model
         self.sampler = Sampler()
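PATCH 0744 is the pivot from a module-level singleton that is mutated via initialize() to a recorder that is constructed once at startup and installed through a setter. The getter body is cut off in the hunk above; a minimal sketch of the accessor pair these call sites assume (the assert and its placement are illustrative, not taken from the patch):

from typing import Optional

_global_expert_distribution_recorder: Optional["ExpertDistributionRecorder"] = None


def get_global_expert_distribution_recorder() -> "ExpertDistributionRecorder":
    # Call sites assume the recorder was installed during ModelRunner.initialize().
    assert _global_expert_distribution_recorder is not None
    return _global_expert_distribution_recorder


def set_global_expert_distribution_recorder(value: "ExpertDistributionRecorder") -> None:
    global _global_expert_distribution_recorder
    _global_expert_distribution_recorder = value

With this shape, everything after construction goes through the getter, which is what lets the following commits swap in different recorder implementations without touching call sites.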
From a85ebf11d69c973baae6a00cde5ff21a15104345 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 08:59:52 +0800
Subject: [PATCH 0745/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 415161deb28..b5d9ef87855 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -21,19 +21,17 @@
 class ExpertDistributionRecorder:
     """Global expert distribution recording"""

-    def __init__(self):
-        self._recording = False
-        # TODO improve API
-        self._enable_in_cuda_graph = get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_ENABLE_IN_CUDA_GRAPH")
-        self._current_layer_idx = Withable()
-        self._current_debug_name = Withable()
-
-    def initialize(
+    def __init__(
         self,
         server_args: ServerArgs,
         expert_location_metadata: "ExpertLocationMetadata",
         rank: int,
     ):
+        self._recording = False
+        # TODO improve API
+        self._enable_in_cuda_graph = get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_ENABLE_IN_CUDA_GRAPH")
+        self._current_layer_idx = Withable()
+        self._current_debug_name = Withable()
         self._server_args = server_args
         self._expert_location_metadata = expert_location_metadata
         self._accumulator = _Accumulator.init_new(expert_location_metadata, rank)

From 2481317c6e8fd3e43a20615cb89614b5c2ab7664 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 08:59:58 +0800
Subject: [PATCH 0746/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index b5d9ef87855..feb183343e1 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -27,13 +27,14 @@ def __init__(
         expert_location_metadata: "ExpertLocationMetadata",
         rank: int,
     ):
+        self._server_args = server_args
+        self._expert_location_metadata = expert_location_metadata
+
         self._recording = False
         # TODO improve API
         self._enable_in_cuda_graph = get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_ENABLE_IN_CUDA_GRAPH")
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
-        self._server_args = server_args
-        self._expert_location_metadata = expert_location_metadata
         self._accumulator = _Accumulator.init_new(expert_location_metadata, rank)
         self._single_pass_gatherers = {
             k: _SinglePassGatherer.init_new(server_args, expert_location_metadata)

From d4ed201454cfc4870836f63e4c0d072fa671677b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:00:30 +0800
Subject: [PATCH 0747/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 11 +++++++++++
 python/sglang/srt/model_executor/model_runner.py  |  2 +-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index feb183343e1..5fcf56e1e50 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -21,6 +21,17 @@
 class ExpertDistributionRecorder:
     """Global expert distribution recording"""

+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        if server_args.enable_expert_distribution_recorder:
+            return TODO
+        else:
+            return TODO
+
     def __init__(
         self,
         server_args: ServerArgs,

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f1b7971c2c0..63141f3c908 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -211,7 +211,7 @@ def initialize(self, min_per_gpu_memory: float):
             enable=self.server_args.enable_memory_saver
         )

-        set_global_expert_distribution_recorder(ExpertDistributionRecorder(
+        set_global_expert_distribution_recorder(ExpertDistributionRecorder.init_new(
             server_args,
             get_global_expert_location_metadata(),
             rank=self.tp_rank,

From e86fe46daff517ba689acbc72d33320f7ec76f17 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:00:48 +0800
Subject: [PATCH 0748/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 5fcf56e1e50..b80a4f16bc3 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -32,6 +32,12 @@ def init_new(
         else:
             return TODO

+
+class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
+    pass
+
+
+class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
     def __init__(
         self,
         server_args: ServerArgs,

From 20bca60eaea41650583aeb79967e697da16ddf1f Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:00:59 +0800
Subject: [PATCH 0749/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index b80a4f16bc3..57f236b81c8 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -28,9 +28,9 @@ def init_new(
         rank: int,
     ):
         if server_args.enable_expert_distribution_recorder:
-            return TODO
+            return _ExpertDistributionRecorderReal(server_args, expert_location_metadata, rank)
         else:
-            return TODO
+            return _ExpertDistributionRecorderNoop()

From ff273a8071d640741a2741fe9fa77a99301084e6 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:01:21 +0800
Subject: [PATCH 0750/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 57f236b81c8..1ddde31194d 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -32,6 +32,14 @@ def init_new(
         else:
             return _ExpertDistributionRecorderNoop()

+    @contextmanager
+    def with_current_layer(self, layer_idx):
+        pass
+
+    @contextmanager
+    def with_debug_name(self, debug_name):
+        pass
+

 class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
     pass

From 35bdd360d01fba32052ab84adbcca63b1bf974ab Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:01:32 +0800
Subject: [PATCH 0751/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 1ddde31194d..6b1b0ba53e2 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -34,11 +34,15 @@ def init_new(

     @contextmanager
     def with_current_layer(self, layer_idx):
-        pass
+        yield

     @contextmanager
     def with_debug_name(self, debug_name):
-        pass
+        yield
+
+    @contextmanager
+    def with_forward_pass(self, forward_pass_id: int):
+        yield

 class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):

From 5006c343dd45f1e5326d96ee15b75a5de38b1842 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:02:04 +0800
Subject: [PATCH 0752/1089] more

---
 .../sglang/srt/managers/expert_distribution.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 6b1b0ba53e2..19d029e8365 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -44,6 +44,24 @@ def with_debug_name(self, debug_name):
     def with_forward_pass(self, forward_pass_id: int):
         yield

+    def on_select_experts(self, topk_ids: torch.Tensor):
+        pass
+
+    def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]):
+        pass
+
+    def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor):
+        pass
+
+    def start_record(self):
+        pass
+
+    def stop_record(self):
+        pass
+
+    def dump_record(self):
+        raise NotImplementedError
+

 class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
     pass

From 71150b309b9670e3e33feb15908432be8a3891f8 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:03:22 +0800
Subject: [PATCH 0753/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 19d029e8365..82d2993719d 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -54,13 +54,17 @@ def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor):
         pass

     def start_record(self):
-        pass
+        self._on_not_implemented()

     def stop_record(self):
-        pass
+        self._on_not_implemented()

     def dump_record(self):
-        raise NotImplementedError
+        self._on_not_implemented()
+
+    def _on_not_implemented(self):
+        raise Exception(
+            "Please enable ServerArgs.enable_expert_distribution_recorder to use ExpertDistributionRecorder.")
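PATCH 0748 through 0753 converge on a null-object design: the base class doubles as the no-op implementation (context managers that just yield, hooks that do nothing), so per-layer and per-pass call sites never branch on whether recording is enabled, while start/stop/dump fail loudly when the feature is off. A condensed, self-contained sketch of the shape (names abbreviated; the counting logic in the "real" class is a stand-in, not the actual accumulator):

from contextlib import contextmanager


class Recorder:
    @staticmethod
    def init_new(enabled: bool) -> "Recorder":
        return _RecorderReal() if enabled else _RecorderNoop()

    @contextmanager
    def with_current_layer(self, layer_idx):
        yield  # no-op default keeps hot paths cheap when disabled

    def on_select_experts(self, topk_ids):
        pass  # no-op default

    def start_record(self):
        raise Exception("enable the recorder to use this")


class _RecorderNoop(Recorder):
    pass  # inherits all no-op behavior


class _RecorderReal(Recorder):
    def __init__(self):
        self._num_hook_calls = 0

    def on_select_experts(self, topk_ids):
        self._num_hook_calls += 1  # stand-in for real per-expert accumulation

    def start_record(self):
        pass  # the real class overrides the loud failure

Only the lifecycle methods (start/stop/dump) raise on the base class, because calling them against a disabled recorder is a user error, whereas the hooks fire on every forward pass and must be silent.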
From 588c492340c72b09283b913eee58aa870493a348 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:03:39 +0800
Subject: [PATCH 0754/1089] more

---
 python/sglang/srt/managers/expert_distribution.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 82d2993719d..cc8f404d071 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -82,8 +82,6 @@ def __init__(
         self._expert_location_metadata = expert_location_metadata

         self._recording = False
-        # TODO improve API
-        self._enable_in_cuda_graph = get_bool_env_var("SGLANG_EXPERT_DISTRIBUTION_RECORDER_ENABLE_IN_CUDA_GRAPH")
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
         self._accumulator = _Accumulator.init_new(expert_location_metadata, rank)
@@ -134,7 +132,7 @@ def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor):
         self._on_hook("on_deepep_dispatch_low_latency", recv_count=recv_count)

     def _on_hook(self, hook_name: str, **kwargs):
-        if not (self._recording or (self._enable_in_cuda_graph and torch.cuda.is_current_stream_capturing())):
+        if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
             self._accumulator.get_single_pass_gatherer_key(

From 54bac87db9e35149c713c3f3d652d677f84c7595 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:05:35 +0800
Subject: [PATCH 0755/1089] more

---
 python/sglang/srt/managers/expert_location.py | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index be7caf23b09..5b5a84d7b02 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -13,15 +13,26 @@

 @dataclass
 class ExpertLocationMetadata:
-    num_layers: int
-    num_local_physical_experts: int
-    num_logical_experts: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
     # (num_gpus, layers, num_logical_experts)
     logical_to_rank_dispatch_physical_map: torch.Tensor

+    # -------------------------------- properties ------------------------------------
+
+    @property
+    def num_layers(self) -> int:
+        return self.physical_to_logical_map.shape[0]
+
+    @property
+    def num_local_physical_experts(self) -> int:
+        return TODO
+
+    @property
+    def num_logical_experts(self) -> int:
+        return self.logical_to_all_physical_map.shape[1]
+
     # -------------------------------- construction and mutation ------------------------------------

     @staticmethod
@@ -34,8 +45,8 @@ def init_trivial(server_args: ServerArgs):
         num_logical_experts = model_config_for_expert_location.num_logical_experts

         physical_to_logical_map = (
-            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
-            % num_logical_experts
+                torch.arange(0, num_physical_experts).repeat(num_layers, 1)
+                % num_logical_experts
         )

         return ExpertLocationMetadata.init_by_mapping(
@@ -108,8 +119,8 @@ def _init_common(server_args: ServerArgs):
         )

         num_physical_experts = (
-            model_config_for_expert_location.num_logical_experts
-            + server_args.ep_num_redundant_experts
+                model_config_for_expert_location.num_logical_experts
+                + server_args.ep_num_redundant_experts
         )
         # TODO consider case when DP attention is disabled and DP > 1
         world_size = server_args.tp_size
@@ -162,7 +173,7 @@ def physical_to_local_physical(self, global_physical_expert_index: int):
         return global_physical_expert_index % self.num_local_physical_experts

     def logical_to_all_physical(
-        self, layer_id: int, logical_expert_id: int
+            self, layer_id: int, logical_expert_id: int
     ) -> List[int]:
         return self.logical_to_all_physical_raw(
             self.logical_to_all_physical_map, layer_id, logical_expert_id
@@ -170,7 +181,7 @@ def logical_to_all_physical(

     @staticmethod
     def logical_to_all_physical_raw(
-        logical_to_all_physical_map, layer_id: int, logical_expert_id: int
+            logical_to_all_physical_map, layer_id: int, logical_expert_id: int
     ) -> List[int]:
         return [
             physical_expert_id
@@ -190,7 +201,7 @@ def debug_str(self):

 def _compute_logical_to_all_physical_map(
-    physical_to_logical_map: torch.Tensor, num_logical_experts: int
+        physical_to_logical_map: torch.Tensor, num_logical_experts: int
 ):
     # This is rarely called, so we use for loops for maximum clarity

@@ -229,8 +240,8 @@ def _compute_logical_to_all_physical_map_num_valid(logical_to_all_physical_map):

 def _compute_logical_to_rank_dispatch_physical_map(
-    logical_to_all_physical_map: torch.Tensor,
-    num_gpus: int,
+        logical_to_all_physical_map: torch.Tensor,
+        num_gpus: int,
 ):
     # TODO maybe improve this algorithm (e.g. ensure it is really balanced)
     # This is rarely called, so we use for loops for maximum clarity

From fd50ca69f55be1168d5d7cdbea01419f10c47e6b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:05:56 +0800
Subject: [PATCH 0756/1089] more

---
 python/sglang/srt/managers/expert_location.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 5b5a84d7b02..92bf2029ab5 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -67,9 +67,6 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map):
         )

         return ExpertLocationMetadata(
-            num_layers=model_config_for_expert_location.num_layers,
-            num_logical_experts=model_config_for_expert_location.num_logical_experts,
-            num_local_physical_experts=common["num_local_physical_experts"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
             logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
@@ -98,9 +95,6 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor):
         )

         return ExpertLocationMetadata(
-            num_layers=model_config_for_expert_location.num_layers,
-            num_logical_experts=model_config_for_expert_location.num_logical_experts,
-            num_local_physical_experts=common["num_local_physical_experts"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
             logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
@@ -135,12 +129,13 @@ def _init_common(server_args: ServerArgs):
         )

     def update(self, other: "ExpertLocationMetadata"):
-        for field in [
-            "num_layers",
-            "num_local_physical_experts",
-            "num_logical_experts",
-        ]:
-            assert getattr(self, field) == getattr(other, field)
+        # TODO remove
+        # for field in [
+        #     "num_layers",
+        #     "num_local_physical_experts",
+        #     "num_logical_experts",
+        # ]:
+        #     assert getattr(self, field) == getattr(other, field)

         for field in [
             "physical_to_logical_map",

From de33987b31b3b116749448fa68303b3fee776791 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:06:22 +0800
Subject: [PATCH 0757/1089] more

---
 python/sglang/srt/managers/expert_location.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 92bf2029ab5..f6d1198a762 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -13,6 +13,7 @@

 @dataclass
 class ExpertLocationMetadata:
+    ep_size: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
@@ -67,6 +68,7 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map):
         )

         return ExpertLocationMetadata(
+            ep_size=TODO,
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
             logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
@@ -95,6 +97,7 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor):
         )

         return ExpertLocationMetadata(
+            ep_size=TODO,
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
             logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(

From e151448a0ab55c669587f1555543c7ce37320a8f Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:06:57 +0800
Subject: [PATCH 0758/1089] more

---
 python/sglang/srt/managers/expert_location.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index f6d1198a762..ae9a557ce48 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -26,9 +26,15 @@ class ExpertLocationMetadata:
     def num_layers(self) -> int:
         return self.physical_to_logical_map.shape[0]

+    @property
+    def num_physical_experts(self) -> int:
+        return self.physical_to_logical_map.shape[1]
+
     @property
     def num_local_physical_experts(self) -> int:
-        return TODO
+        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
+        assert remainder == 0
+        return ans

     @property
     def num_logical_experts(self) -> int:

From 933d558a233f7657fbf787ef15c462f8702108fb Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:09:55 +0800
Subject: [PATCH 0759/1089] more

---
 python/sglang/srt/managers/expert_location.py | 37 ++++++++++++-------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index ae9a557ce48..c3c14e7c15e 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -73,16 +73,11 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map):
             num_logical_experts=model_config_for_expert_location.num_logical_experts,
         )

-        return ExpertLocationMetadata(
+        return ExpertLocationMetadata._init_raw(
             ep_size=TODO,
+            num_gpus=common["world_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
-            logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
-                logical_to_all_physical_map),
-            logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(
-                logical_to_all_physical_map,
-                num_gpus=common["world_size"],
-            ),
         )

     @staticmethod
@@ -102,16 +97,11 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor):
             )
         )

-        return ExpertLocationMetadata(
+        return ExpertLocationMetadata._init_raw(
             ep_size=TODO,
+            num_gpus=common["world_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
-            logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
-                logical_to_all_physical_map),
-            logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(
-                logical_to_all_physical_map,
-                num_gpus=common["world_size"],
-            ),
         )

     @staticmethod
@@ -137,6 +127,25 @@ def _init_common(server_args: ServerArgs):
             world_size=world_size,
         )

+    @staticmethod
+    def _init_raw(
+        ep_size: int,
+        num_gpus: int,
+        physical_to_logical_map: torch.Tensor,
+        logical_to_all_physical_map: torch.Tensor,
+    ):
+        return ExpertLocationMetadata(
+            ep_size=ep_size,
+            physical_to_logical_map=physical_to_logical_map,
+            logical_to_all_physical_map=logical_to_all_physical_map,
+            logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
+                logical_to_all_physical_map),
+            logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(
+                logical_to_all_physical_map,
+                num_gpus=num_gpus,
+            ),
+        )
+
     def update(self, other: "ExpertLocationMetadata"):
         # TODO remove
         # for field in [

From 12aa2edf9d520e31e99f0efe300ec3244bb959c0 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:10:16 +0800
Subject: [PATCH 0760/1089] more

---
 python/sglang/srt/managers/expert_location.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index c3c14e7c15e..c232f34f8b8 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -115,8 +115,7 @@ def _init_common(server_args: ServerArgs):
             model_config_for_expert_location.num_logical_experts
             + server_args.ep_num_redundant_experts
         )
-        # TODO consider case when DP attention is disabled and DP > 1
-        world_size = server_args.tp_size
+        world_size = server_args.ep_size
         assert num_physical_experts % world_size == 0
         num_local_physical_experts = num_physical_experts // world_size

From 281e31d090173810641dc2649fa9307dde40692d Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:10:34 +0800
Subject: [PATCH 0761/1089] more

---
 python/sglang/srt/managers/expert_location.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index c232f34f8b8..7a07cdf97b9 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -75,7 +75,7 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map):

         return ExpertLocationMetadata._init_raw(
             ep_size=TODO,
-            num_gpus=common["world_size"],
+            num_gpus=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
         )
@@ -93,13 +93,13 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor):
             num_replicas=common["num_physical_experts"],
             num_groups=model_config_for_expert_location.num_groups,
             num_nodes=server_args.nnodes,
-            num_gpus=common["world_size"],
+            num_gpus=common["ep_size"],
             )
         )

         return ExpertLocationMetadata._init_raw(
             ep_size=TODO,
-            num_gpus=common["world_size"],
+            num_gpus=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
         )
@@ -115,15 +115,15 @@ def _init_common(server_args: ServerArgs):
             + server_args.ep_num_redundant_experts
         )
-        world_size = server_args.ep_size
-        assert num_physical_experts % world_size == 0
-        num_local_physical_experts = num_physical_experts // world_size
+        ep_size = server_args.ep_size
+        assert num_physical_experts % ep_size == 0
+        num_local_physical_experts = num_physical_experts // ep_size

         return dict(
             model_config_for_expert_location=model_config_for_expert_location,
             num_physical_experts=num_physical_experts,
             num_local_physical_experts=num_local_physical_experts,
-            world_size=world_size,
+            ep_size=ep_size,
         )

     @staticmethod

From 244cf7930ae995b1ee2ca829982810e1f2d2700b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:10:47 +0800
Subject: [PATCH 0762/1089] more

---
 python/sglang/srt/managers/expert_location.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 7a07cdf97b9..9632f5f5018 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -74,8 +74,7 @@ def init_by_mapping(server_args: ServerArgs, physical_to_logical_map):
         )

         return ExpertLocationMetadata._init_raw(
-            ep_size=TODO,
-            num_gpus=common["ep_size"],
+            ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
         )
@@ -98,8 +97,7 @@ def init_by_eplb(server_args: ServerArgs, logical_count: torch.Tensor):
         )

         return ExpertLocationMetadata._init_raw(
-            ep_size=TODO,
-            num_gpus=common["ep_size"],
+            ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
         )
@@ -129,7 +127,6 @@
     @staticmethod
     def _init_raw(
         ep_size: int,
-        num_gpus: int,
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
     ):
@@ -141,7 +138,7 @@ def _init_raw(
             logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(
                 logical_to_all_physical_map,
-                num_gpus=num_gpus,
+                num_gpus=ep_size,
             ),
         )

From de4de7df2609de96555e011327acf8ebb7388f23 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:11:41 +0800
Subject: [PATCH 0763/1089] more

---
 python/sglang/srt/managers/expert_location.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 9632f5f5018..039787a629b 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -143,13 +143,10 @@ def _init_raw(
         )

     def update(self, other: "ExpertLocationMetadata"):
-        # TODO remove
-        # for field in [
-        #     "num_layers",
-        #     "num_local_physical_experts",
-        #     "num_logical_experts",
-        # ]:
-        #     assert getattr(self, field) == getattr(other, field)
+        for field in [
+            "ep_size",
+        ]:
+            assert getattr(self, field) == getattr(other, field)

         for field in [
             "physical_to_logical_map",

From 4066b051620602730dc460f9bf05a63e534c2627 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:12:11 +0800
Subject: [PATCH 0764/1089] more

---
 python/sglang/srt/managers/expert_location.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 039787a629b..191ee815abb 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -150,18 +150,13 @@ def update(self, other: "ExpertLocationMetadata"):

         for field in [
             "physical_to_logical_map",
+            "logical_to_all_physical_map",
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
             # Cannot update address to avoid breaking CUDA graph
             getattr(self, field)[...] = getattr(other, field)

-        for field in [
-            # TODO maybe make last dim size const to simplify the logic
-            "logical_to_all_physical_map",
-        ]:
-            setattr(self, field, getattr(other, field).to(getattr(self, field).device))
-
     def to(self, device):
         for field in [
             "logical_to_all_physical_map",

From e8420be075a991585a791d8adcfe288a0fac39ac Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:12:54 +0800
Subject: [PATCH 0765/1089] more

---
 python/sglang/srt/managers/expert_location.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 191ee815abb..c653a65129b 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -130,12 +130,14 @@ def _init_raw(
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
     ):
+        logical_to_all_physical_map_padded = TODO
+        logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)
+
         return ExpertLocationMetadata(
             ep_size=ep_size,
             physical_to_logical_map=physical_to_logical_map,
-            logical_to_all_physical_map=logical_to_all_physical_map,
-            logical_to_all_physical_map_num_valid=_compute_logical_to_all_physical_map_num_valid(
-                logical_to_all_physical_map),
+            logical_to_all_physical_map=logical_to_all_physical_map_padded,
+            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map(
                 logical_to_all_physical_map,
                 num_gpus=ep_size,
@@ -236,10 +238,6 @@ def _pad_nested_array(arr, pad_value):
     return padded

-def _compute_logical_to_all_physical_map_num_valid(logical_to_all_physical_map):
-    return torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)
-
-
 def _compute_logical_to_rank_dispatch_physical_map(
     logical_to_all_physical_map: torch.Tensor,
     num_gpus: int,
 ):

From 58b174e5ade6ee84e2d220c16093216f4dad8d50 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:13:19 +0800
Subject: [PATCH 0766/1089] more

---
 python/sglang/srt/managers/expert_location.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index c653a65129b..05de56e38ce 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -5,6 +5,7 @@
 from typing import List, Optional

 import torch
+import torch.nn.functional as F
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.managers import deepseek_eplb
 from sglang.srt.model_loader import get_model_architecture
@@ -130,7 +131,9 @@ def _init_raw(
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
     ):
-        logical_to_all_physical_map_padded = TODO
+        logical_to_all_physical_map_padded = F.pad(logical_to_all_physical_map,
+                                                   (0, num_physical_experts - current_last_dim), value=-1)

         logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)
From b47b22792148c4bd87c51ecd1acefe35a73798b2 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:13:40 +0800
Subject: [PATCH 0767/1089] more

---
 python/sglang/srt/managers/expert_location.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 05de56e38ce..b4c09fc2b37 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -131,8 +131,13 @@ def _init_raw(
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
     ):
-        logical_to_all_physical_map_padded = F.pad(logical_to_all_physical_map,
-                                                   (0, num_physical_experts - current_last_dim), value=-1)
+        _, num_physical_experts = physical_to_logical_map.shape
+
+        logical_to_all_physical_map_padded = F.pad(
+            logical_to_all_physical_map,
+            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
+            value=-1,
+        )

         logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)

From 8d4e7375d23a99aec5204166a531790999afca3b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:14:10 +0800
Subject: [PATCH 0768/1089] more

---
 python/sglang/srt/managers/expert_location.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index b4c09fc2b37..dfada125632 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -41,6 +41,9 @@
     def num_logical_experts(self) -> int:
         return self.logical_to_all_physical_map.shape[1]

+    def __post_init__(self):
+        TODO
+
     # -------------------------------- construction and mutation ------------------------------------

     @staticmethod

From 5504b9a2de3c2a3bde92a4c9ce2cb7da164467f1 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:15:41 +0800
Subject: [PATCH 0769/1089] more

---
 python/sglang/srt/managers/expert_location.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index dfada125632..106ac56bd7a 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -42,7 +42,13 @@
         return self.logical_to_all_physical_map.shape[1]

     def __post_init__(self):
-        TODO
+        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
+        num_layers_1, num_logical_experts_0, num_physical_experts_1 = self.logical_to_all_physical_map.shape
+        num_layers_2, num_logical_experts_1 = self.logical_to_all_physical_map_num_valid.shape
+        ep_size_0, num_layers_3, num_logical_experts_2 = self.logical_to_rank_dispatch_physical_map.shape
+        assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
+        assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
+        assert num_physical_experts_0 == num_physical_experts_1

From 0e9cd9bb9d2091d59970a0dbe72b950dff26a4e9 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:16:13 +0800
Subject: [PATCH 0770/1089] more

---
 python/sglang/srt/managers/expert_location.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index 106ac56bd7a..b1d16e3482f 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -14,7 +14,6 @@

 @dataclass
 class ExpertLocationMetadata:
-    ep_size: int
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
@@ -40,6 +40,10 @@
     def num_logical_experts(self) -> int:
         return self.logical_to_all_physical_map.shape[1]

+    @property
+    def ep_size(self):
+        return self.logical_to_rank_dispatch_physical_map.shape[0]
+
     def __post_init__(self):
         num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
         num_layers_1, num_logical_experts_0, num_physical_experts_1 = self.logical_to_all_physical_map.shape
@@ -151,7 +154,6 @@ def _init_raw(
         logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)

         return ExpertLocationMetadata(
-            ep_size=ep_size,
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,

From 7decf569a920b6b83a641920d8c43a1e893996a7 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:16:50 +0800
Subject: [PATCH 0771/1089] more

---
 python/sglang/srt/managers/expert_location.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
index b1d16e3482f..d118616dd4c 100644
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -53,7 +53,7 @@ def __post_init__(self):
         assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
         assert num_physical_experts_0 == num_physical_experts_1

-    # -------------------------------- construction and mutation ------------------------------------
+    # -------------------------------- construction ------------------------------------

     @staticmethod
     def init_trivial(server_args: ServerArgs):
@@ -163,6 +163,8 @@ def _init_raw(
             ),
         )

+    # -------------------------------- mutation ------------------------------------
+
     def update(self, other: "ExpertLocationMetadata"):
         for field in [
             "ep_size",

From 6664524c13a7874af3cc0b271f57e3549bc59173 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:17:53 +0800
Subject: [PATCH 0772/1089] more

---
 python/sglang/srt/server_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index dc92d2cd7a6..afe93436c69 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -163,6 +163,7 @@ class ServerArgs:
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
+    ep_dispatch_algorithm: Optional[Literal["static", "random"]] = None
     init_expert_location: Optional[str] = None
     enable_eplb: bool = False
     eplb_storage_dir: str = "/tmp/eplb_storage"

From 710f1472b17a048a30e6a261622e606f7ec9cd27 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:18:38 +0800
Subject: [PATCH 0773/1089] more

---
 python/sglang/srt/server_args.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index afe93436c69..c877b7bbbfe 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1134,6 +1134,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=ServerArgs.ep_num_redundant_experts,
         help="Allocate this number of redundant experts in expert parallel.",
     )
+    parser.add_argument(
+        "--ep-dispatch-algorithm",
+        type=str,
+        default=ServerArgs.ep_dispatch_algorithm,
+        help="The algorithm to choose ranks for redundant experts in expert parallel.",
+    )
     parser.add_argument(
         "--init-expert-location",
         type=str,
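The new --ep-dispatch-algorithm flag deliberately defaults to None; the commit that follows auto-selects "static" whenever EPLB or an initial expert location is configured. A small sketch of that precedence rule as a pure function (the function name and signature are illustrative, not part of the patch):

from typing import Optional


def resolve_ep_dispatch_algorithm(
    explicit: Optional[str], enable_eplb: bool, init_expert_location: Optional[str]
) -> Optional[str]:
    # EPLB (or a preset expert location) forces the precomputed static dispatch map,
    # even over an explicitly requested algorithm, mirroring the __post_init__ hunk below.
    if enable_eplb or (init_expert_location is not None):
        return "static"
    return explicit


assert resolve_ep_dispatch_algorithm("random", enable_eplb=True, init_expert_location=None) == "static"
assert resolve_ep_dispatch_algorithm("random", enable_eplb=False, init_expert_location=None) == "random"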
From c93ecdc90ced421bf8aeebf867261aa13e519da7 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:19:25 +0800
Subject: [PATCH 0774/1089] more

---
 python/sglang/srt/server_args.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c877b7bbbfe..dcbb8d050a4 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -341,6 +341,10 @@ def __post_init__(self):
             logger.info(
                 f"EPLB is enabled. The enable_scheduler_input_blocker and enable_expert_distribution_recorder are automatically enabled."
             )
+        if self.enable_eplb or (self.init_expert_location is not None):
+            self.ep_dispatch_algorithm = "static"
+            logger.info(
+                f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is set to `static`.")

         if self.ep_num_redundant_experts > 0:
             assert (

From dd936a42113df09a8fab4195f608629ddea3e472 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:20:15 +0800
Subject: [PATCH 0775/1089] more

---
 python/sglang/srt/managers/tokenizer_manager.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 96b3285dd9c..378c0b93479 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -45,7 +45,6 @@
 import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
-
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.conn import KVBootstrapServer
@@ -706,8 +705,8 @@ async def eplb_save_expert_distribution(self):
     async def update_expert_location(self, obj: UpdateExpertLocationReqInput):
         self.auto_create_handle_loop()
         assert (
-            self.server_args.enable_scheduler_input_blocker
-        ), f"update_expert_location requires --enable-scheduler-input-blocker"
+            self.server_args.enable_scheduler_input_blocker and (self.server_args.ep_dispatch_algorithm is not None)
+        ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm"

         self.expert_location_metadata = None
@@ -1026,8 +1025,8 @@ def _handle_batch_output(
         elif isinstance(recv_obj, BatchTokenIDOut):
             if self.server_args.stream_output and state.obj.stream:
                 output_token_ids = recv_obj.output_ids[i][
-                    state.last_output_offset :
-                ]
+                    state.last_output_offset:
+                ]
                 state.last_output_offset = len(recv_obj.output_ids[i])
             else:
                 output_token_ids = recv_obj.output_ids[i]

From b2f073de9f2a5e21b97c6d1cb19fca5994e20811 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:20:36 +0800
Subject: [PATCH 0776/1089] empty

---
 python/sglang/srt/layers/moe/expert_location_dispatch.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 python/sglang/srt/layers/moe/expert_location_dispatch.py

diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py
new file mode 100644
index 00000000000..e69de29bb2d
From 5b94e196d05226e4c23548b585caa01e28049a2f Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:21:08 +0800
Subject: [PATCH 0777/1089] more

---
 .../srt/layers/moe/expert_location_dispatch.py | 16 ++++++++++++++++
 python/sglang/srt/layers/moe/topk.py           | 16 ----------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py
index e69de29bb2d..6e82b424eca 100644
--- a/python/sglang/srt/layers/moe/expert_location_dispatch.py
+++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py
@@ -0,0 +1,16 @@
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _hack_expert_location_dispatch_random(
+    topk_ids: torch.Tensor,
+    expert_logical_to_all_physical_map: torch.Tensor,
+    expert_logical_to_all_physical_map_num_valid: torch.Tensor,
+):
+    topk_ids_original_shape = topk_ids.shape
+    device = topk_ids.device
+    topk_ids = topk_ids.flatten()
+
+    chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
+                             % expert_logical_to_all_physical_map_num_valid[topk_ids])
+    topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]
+
+    topk_ids = topk_ids.view(topk_ids_original_shape)
+    return topk_ids

diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 2061d5f36f3..1fe6032beb7 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -316,19 +316,3 @@ def select_experts(

     return topk_weights, topk_ids

-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _hack_expert_location_dispatch_random(
-    topk_ids: torch.Tensor,
-    expert_logical_to_all_physical_map: torch.Tensor,
-    expert_logical_to_all_physical_map_num_valid: torch.Tensor,
-):
-    topk_ids_original_shape = topk_ids.shape
-    device = topk_ids.device
-    topk_ids = topk_ids.flatten()
-
-    chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
-                             % expert_logical_to_all_physical_map_num_valid[topk_ids])
-    topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]
-
-    topk_ids = topk_ids.view(topk_ids_original_shape)
-    return topk_ids

From b23d14d3952cf63c8a02760a557bca9b6bd6f39d Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:21:26 +0800
Subject: [PATCH 0778/1089] more

---
 .../sglang/srt/layers/moe/expert_location_dispatch.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py
index 6e82b424eca..6e1ab64d6b5 100644
--- a/python/sglang/srt/layers/moe/expert_location_dispatch.py
+++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py
@@ -1,3 +1,14 @@
+from dataclasses import dataclass
+
+import torch
+from sglang.srt.utils import get_compiler_backend
+
+
+@dataclass
+class ExpertLocationDispatchInfo:
+    pass
+
+
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def _hack_expert_location_dispatch_random(
     topk_ids: torch.Tensor,
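For intuition, the helper moved above can be exercised standalone (toy tensors on CPU, with the @torch.compile wrapper omitted). Logical expert 0 below has two physical replicas, expert 1 has one:

import torch

topk_ids = torch.tensor([[0, 1], [0, 0]])  # (num_tokens, top_k), logical ids
logical_to_all_physical = torch.tensor([[2, 5], [3, -1]])  # (num_logical, X)
num_valid = torch.tensor([2, 1])

flat = topk_ids.flatten()
chosen = torch.randint(0, 65536, flat.shape) % num_valid[flat]
physical = logical_to_all_physical[flat, chosen].view(topk_ids.shape)
# every 0 becomes 2 or 5 (chosen at random); every 1 always becomes 3

The randint-then-modulo trick keeps replica selection a handful of elementwise tensor ops with no data-dependent control flow, which is presumably why the original is wrapped in @torch.compile.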
b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Literal import torch from sglang.srt.utils import get_compiler_backend @@ -6,7 +7,7 @@ @dataclass class ExpertLocationDispatchInfo: - pass + ep_dispatch_algorithm: Literal["static", "random"] @torch.compile(dynamic=True, backend=get_compiler_backend()) From 050e9904469ff6de79442b8418ad34c305f1cd5d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:22:19 +0800 Subject: [PATCH 0780/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index dd68852966f..cd67c839bfc 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -10,7 +10,10 @@ class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] -@torch.compile(dynamic=True, backend=get_compiler_backend()) +def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: ExpertLocationDispatchInfo): + return TODO + + def _hack_expert_location_dispatch_random( topk_ids: torch.Tensor, expert_logical_to_all_physical_map: torch.Tensor, From 0560cefcde23ca27b2db73c4515ef0ef549d7fd0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:22:33 +0800 Subject: [PATCH 0781/1089] more --- .../srt/layers/moe/expert_location_dispatch.py | 14 +++++++++++++- python/sglang/srt/layers/moe/topk.py | 11 ----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index cd67c839bfc..1c27e22614f 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -10,9 +10,21 @@ class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] -def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: ExpertLocationDispatchInfo): +def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: ExpertLocationDispatchInfo) -> torch.Tensor: return TODO + # TODO + if expert_logical_to_rank_dispatch_physical_map is not None: + # TODO optimize these things later + if forward_mode.is_extend(): + topk_ids = _hack_expert_location_dispatch_random( + topk_ids=topk_ids, + expert_logical_to_all_physical_map=expert_logical_to_all_physical_map, + expert_logical_to_all_physical_map_num_valid=expert_logical_to_all_physical_map_num_valid, + ) + else: + topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] + def _hack_expert_location_dispatch_random( topk_ids: torch.Tensor, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 1fe6032beb7..9a656393e80 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -300,17 +300,6 @@ def select_experts( renormalize=renormalize, ) - if expert_logical_to_rank_dispatch_physical_map is not None: - # TODO optimize these things later - if forward_mode.is_extend(): - topk_ids = _hack_expert_location_dispatch_random( - topk_ids=topk_ids, - expert_logical_to_all_physical_map=expert_logical_to_all_physical_map, - expert_logical_to_all_physical_map_num_valid=expert_logical_to_all_physical_map_num_valid, - ) - else: - topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] - 
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids From 84e5ab7b6b715cfc96c2bfb8a34aae3aa8173529 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:22:45 +0800 Subject: [PATCH 0782/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 6 ++++++ python/sglang/srt/layers/moe/topk.py | 4 ---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 1c27e22614f..2e7824585b1 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -13,6 +13,12 @@ class ExpertLocationDispatchInfo: def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: ExpertLocationDispatchInfo) -> torch.Tensor: return TODO + # TODO + # expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, + # forward_mode=None, + # expert_logical_to_all_physical_map=None, + # expert_logical_to_all_physical_map_num_valid=None, + # TODO if expert_logical_to_rank_dispatch_physical_map is not None: # TODO optimize these things later diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 9a656393e80..7736dc6e9aa 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -245,10 +245,6 @@ def select_experts( custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, - expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, - forward_mode=None, - expert_logical_to_all_physical_map=None, - expert_logical_to_all_physical_map_num_valid=None, ): n_share_experts_fusion = 0 if global_server_args_dict["n_share_experts_fusion"] is not None: From 79ac61fff5dbfb731a37a471599a9d68301512e7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:23:06 +0800 Subject: [PATCH 0783/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 4 ++-- python/sglang/srt/layers/moe/topk.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 2e7824585b1..6822cc29774 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal +from typing import Literal, Optional import torch from sglang.srt.utils import get_compiler_backend @@ -10,7 +10,7 @@ class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] -def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: ExpertLocationDispatchInfo) -> torch.Tensor: +def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: return TODO # TODO diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 7736dc6e9aa..bb5c273ac0b 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F +from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.schedule_batch import 
global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip, get_bool_env_var @@ -245,6 +246,7 @@ def select_experts( custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): n_share_experts_fusion = 0 if global_server_args_dict["n_share_experts_fusion"] is not None: @@ -299,5 +301,3 @@ def select_experts( get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids - - From befb9cb46748c9f7d00cd282b8cfec6153a89939 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:23:19 +0800 Subject: [PATCH 0784/1089] more --- python/sglang/srt/layers/moe/topk.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index bb5c273ac0b..374e6c9f162 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -264,8 +264,10 @@ def select_experts( num_expert_group=num_expert_group, topk_group=topk_group, n_share_experts_fusion=n_share_experts_fusion, + expert_location_dispatch_info=expert_location_dispatch_info, ) else: + assert expert_location_dispatch_info is None topk_weights, topk_ids = biased_grouped_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -277,6 +279,7 @@ def select_experts( n_share_experts_fusion=n_share_experts_fusion, ) elif torch_native and custom_routing_function is None: + assert expert_location_dispatch_info is None topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, @@ -284,6 +287,7 @@ def select_experts( renormalize=renormalize, ) elif custom_routing_function is None: + assert expert_location_dispatch_info is None topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -291,6 +295,7 @@ def select_experts( renormalize=renormalize, ) else: + assert expert_location_dispatch_info is None topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, From af7fb7aa470e8e3dedbeda08acbeefa54a3759fd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:23:28 +0800 Subject: [PATCH 0785/1089] more --- python/sglang/srt/layers/moe/topk.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 374e6c9f162..0fcac22eee5 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -101,6 +101,7 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, n_share_experts_fusion: int = 0, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" From 30972bfcbea3ad7185c7126ed057a9c6297b0f20 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:23:38 +0800 Subject: [PATCH 0786/1089] more --- python/sglang/srt/layers/moe/topk.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 0fcac22eee5..bc6a6902503 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -143,7 +143,9 @@ def grouped_topk( ) topk_weights = topk_weights / topk_weights_sum - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + 
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + return topk_weights, topk_ids def biased_grouped_topk_impl( From 3f0301d67d3db2c5d00137cfa8649b8b0537f122 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:23:54 +0800 Subject: [PATCH 0787/1089] more --- python/sglang/srt/layers/moe/topk.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index bc6a6902503..eea9a3465c2 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,7 +16,7 @@ import torch import torch.nn.functional as F -from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo, topk_ids_logical_to_physical from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip, get_bool_env_var @@ -144,6 +144,8 @@ def grouped_topk( topk_weights = topk_weights / topk_weights_sum topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) return topk_weights, topk_ids From c7a73a5c13f0c92bc4528f921b27c92b4a4bc8be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:24:03 +0800 Subject: [PATCH 0788/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 3 +++ python/sglang/srt/layers/moe/topk.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 6822cc29774..d7fbad97f7c 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -11,6 +11,9 @@ class ExpertLocationDispatchInfo: def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: + if info is None: + return topk_ids + return TODO # TODO diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index eea9a3465c2..3464bb83104 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -144,7 +144,7 @@ def grouped_topk( topk_weights = topk_weights / topk_weights_sum topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) - + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) return topk_weights, topk_ids From a71b514d6ea15fa172c1b86426afbe8b44ef6340 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:24:53 +0800 Subject: [PATCH 0789/1089] more --- .../srt/layers/moe/expert_location_dispatch.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index d7fbad97f7c..edb87d87d78 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -14,6 +14,9 @@ def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo if info is None: return topk_ids + if info.ep_dispatch_algorithm == "static": + return TODO + return TODO # TODO @@ -35,6 +38,16 @@ def 
topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] +def _topk_ids_logical_to_physical_static(topk_ids: torch.Tensor, + info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: + return TODO + + +def _topk_ids_logical_to_physical_random(topk_ids: torch.Tensor, + info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: + return TODO + + def _hack_expert_location_dispatch_random( topk_ids: torch.Tensor, expert_logical_to_all_physical_map: torch.Tensor, From ee8b5791012ade9531610782eeebd6910ca16414 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:25:08 +0800 Subject: [PATCH 0790/1089] more --- .../srt/layers/moe/expert_location_dispatch.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index edb87d87d78..3784cc9fe10 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -15,15 +15,10 @@ def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo return topk_ids if info.ep_dispatch_algorithm == "static": - return TODO - - return TODO - - # TODO - # expert_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] = None, - # forward_mode=None, - # expert_logical_to_all_physical_map=None, - # expert_logical_to_all_physical_map_num_valid=None, + return _topk_ids_logical_to_physical_static(topk_ids, info) + if info.ep_dispatch_algorithm == "random": + return _topk_ids_logical_to_physical_random(topk_ids, info) + raise NotImplementedError # TODO if expert_logical_to_rank_dispatch_physical_map is not None: From 0ad92becad51bc107f934338d08b1b3f1bab5802 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:25:23 +0800 Subject: [PATCH 0791/1089] more --- .../layers/moe/expert_location_dispatch.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 3784cc9fe10..efddac5b236 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -20,34 +20,14 @@ def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo return _topk_ids_logical_to_physical_random(topk_ids, info) raise NotImplementedError - # TODO - if expert_logical_to_rank_dispatch_physical_map is not None: - # TODO optimize these things later - if forward_mode.is_extend(): - topk_ids = _hack_expert_location_dispatch_random( - topk_ids=topk_ids, - expert_logical_to_all_physical_map=expert_logical_to_all_physical_map, - expert_logical_to_all_physical_map_num_valid=expert_logical_to_all_physical_map_num_valid, - ) - else: - topk_ids = expert_logical_to_rank_dispatch_physical_map[topk_ids] - def _topk_ids_logical_to_physical_static(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: - return TODO + return expert_logical_to_rank_dispatch_physical_map[topk_ids] def _topk_ids_logical_to_physical_random(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: - return TODO - - -def _hack_expert_location_dispatch_random( - topk_ids: torch.Tensor, - expert_logical_to_all_physical_map: torch.Tensor, - expert_logical_to_all_physical_map_num_valid: torch.Tensor, -): 
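# The body kept below implements the "random" dispatch policy: for each top-k
# entry, pick one of the valid physical replicas of the chosen logical expert.
# Shape notes for quick reference (illustrative, not part of the diff itself):
#   topk_ids                   (num_tokens, k) logical ids, flattened to 1-D first
#   ..._num_valid[topk_ids]    number of valid replicas per entry
#   randint(0, 65536) % n      cheap index into [0, n); uniform up to a small
#                              bias whenever n does not divide 65536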
topk_ids_original_shape = topk_ids.shape device = topk_ids.device topk_ids = topk_ids.flatten() From 03752bd324ffa435cb693473bf01dee521194cff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:25:57 +0800 Subject: [PATCH 0792/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index efddac5b236..883678a9f9b 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -2,6 +2,7 @@ from typing import Literal, Optional import torch +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.utils import get_compiler_backend @@ -9,6 +10,12 @@ class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] + @classmethod + def init_new(cls, expert_location_metadata: ExpertLocationMetadata, ep_rank: int, layer_id: int): + return cls( + ep_dispatch_algorithm=TODO, + ) + def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: if info is None: From ef37703c2a7decacea5cd737634a768e722b41a1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:26:31 +0800 Subject: [PATCH 0793/1089] more --- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/model_executor/model_runner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 202f685aa04..ad9aae396a0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -83,6 +83,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, + "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm, "chunked_prefill_size": ServerArgs.chunked_prefill_size, "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 63141f3c908..f99e32d3242 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -177,6 +177,7 @@ def __init__( "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "moe_dense_tp_size": server_args.moe_dense_tp_size, + "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm, "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, "n_share_experts_fusion": server_args.n_share_experts_fusion, From c1ad30c94579cd9427198715aa610d3a713b51a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:26:54 +0800 Subject: [PATCH 0794/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 883678a9f9b..c63b1aea882 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py 
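The hunk below makes `init_new` the single on/off gate for dispatch: it reads `ep_dispatch_algorithm` out of `global_server_args_dict` (plumbed through from `ServerArgs` in the previous patch) and returns `None` when the option is unset, so call sites can pass the result through unconditionally. A minimal sketch of the resulting control flow, assuming the call-site shape this series arrives at a few patches later:

    info = ExpertLocationDispatchInfo.init_new(...)           # None unless ep_dispatch_algorithm is set
    topk_ids = topk_ids_logical_to_physical(topk_ids, info)   # identity when info is None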
@@ -3,6 +3,7 @@ import torch from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend @@ -12,8 +13,12 @@ class ExpertLocationDispatchInfo: @classmethod def init_new(cls, expert_location_metadata: ExpertLocationMetadata, ep_rank: int, layer_id: int): + ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] + if ep_dispatch_algorithm is None: + return None + return cls( - ep_dispatch_algorithm=TODO, + ep_dispatch_algorithm=ep_dispatch_algorithm, ) From 33b312701a92aeb80327c501122a9ed11e4ea14a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:28:00 +0800 Subject: [PATCH 0795/1089] more --- .../srt/layers/moe/expert_location_dispatch.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index c63b1aea882..ad4b3bcabbf 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -10,6 +10,9 @@ @dataclass class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] + partial_logical_to_rank_dispatch_physical_map: torch.Tensor + partial_logical_to_all_physical_map: torch.Tensor + partial_logical_to_all_physical_map_num_valid: torch.Tensor @classmethod def init_new(cls, expert_location_metadata: ExpertLocationMetadata, ep_rank: int, layer_id: int): @@ -19,6 +22,11 @@ def init_new(cls, expert_location_metadata: ExpertLocationMetadata, ep_rank: int return cls( ep_dispatch_algorithm=ep_dispatch_algorithm, + partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ + ep_rank, layer_id, :], + partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[layer_id, :], + partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[ + layer_id, :], ) @@ -35,7 +43,7 @@ def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo def _topk_ids_logical_to_physical_static(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: - return expert_logical_to_rank_dispatch_physical_map[topk_ids] + return info.partial_logical_to_rank_dispatch_physical_map[topk_ids] def _topk_ids_logical_to_physical_random(topk_ids: torch.Tensor, @@ -45,8 +53,8 @@ def _topk_ids_logical_to_physical_random(topk_ids: torch.Tensor, topk_ids = topk_ids.flatten() chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) - % expert_logical_to_all_physical_map_num_valid[topk_ids]) - topk_ids = expert_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + % info.partial_logical_to_all_physical_map_num_valid[topk_ids]) + topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] topk_ids = topk_ids.view(topk_ids_original_shape) return topk_ids From b4bc01f2b3c8d0529ed894effa9f805a6e77268a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:28:18 +0800 Subject: [PATCH 0796/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index ad4b3bcabbf..4662e5f975f 100644 --- 
a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -4,7 +4,6 @@ import torch from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.utils import get_compiler_backend @dataclass From 1f0a28fe45efb2cd2fb0c92d5ca9de0562b9eb44 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:29:15 +0800 Subject: [PATCH 0797/1089] more --- .../srt/layers/moe/expert_location_dispatch.py | 6 ++++-- python/sglang/srt/models/deepseek_v2.py | 13 +++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 4662e5f975f..822cda02204 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -3,7 +3,7 @@ import torch from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict, get_global_expert_location_metadata @dataclass @@ -14,8 +14,10 @@ class ExpertLocationDispatchInfo: partial_logical_to_all_physical_map_num_valid: torch.Tensor @classmethod - def init_new(cls, expert_location_metadata: ExpertLocationMetadata, ep_rank: int, layer_id: int): + def init_new(cls, ep_rank: int, layer_id: int): ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] + expert_location_metadata = get_global_expert_location_metadata() + if ep_dispatch_algorithm is None: return None diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e97ff9367c1..1e9e9a47408 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -49,6 +49,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE +from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -328,7 +329,6 @@ def forward_deepep( # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) - expert_location_metadata = get_global_expert_location_metadata() topk_weights, topk_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -338,13 +338,10 @@ def forward_deepep( topk_group=self.topk_group, num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, - forward_mode=forward_mode, - expert_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, :], - expert_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[self.layer_id, - :], - expert_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[ - self.layer_id, :], + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + ep_rank=self.tp_rank, + layer_id=self.layer_id, + ), ) # print(f"hi [{get_tensor_model_parallel_rank()}, {self.__class__.__name__}] forward_deepep after-select_experts " # f"{self.layer_id=} 
{topk_weights=} {topk_idx=} ") From 75b9e04d9b55b3f3564203f03a8e30ef50e94427 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:29:35 +0800 Subject: [PATCH 0798/1089] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 9d485014a40..d049d36df69 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple import torch - +from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata try: @@ -236,9 +236,10 @@ def forward( num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, - expert_logical_to_rank_dispatch_physical_map=get_global_expert_location_metadata().logical_to_rank_dispatch_physical_map[ - self.tp_rank, self.layer_id, : - ], + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + ep_rank=self.tp_rank, + layer_id=self.layer_id, + ), ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -276,7 +277,7 @@ def forward( BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -479,7 +480,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -511,11 +512,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight From c5712e6f8dd78b918d00bc5b878816b9ca735ee1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:29:41 +0800 Subject: [PATCH 0799/1089] more --- python/sglang/srt/layers/moe/expert_location_dispatch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 822cda02204..78b59438b31 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -2,7 +2,6 @@ from typing import Literal, Optional import torch -from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.schedule_batch import global_server_args_dict, get_global_expert_location_metadata From 7d3f6c36304916eb07e6191e0868e7c7d2fccb15 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:31:09 +0800 Subject: [PATCH 0800/1089] more --- python/sglang/srt/managers/expert_location.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 
deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index d118616dd4c..eb4765a482e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -263,9 +263,7 @@ def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, ): - # TODO maybe improve this algorithm (e.g. ensure it is really balanced) # This is rarely called, so we use for loops for maximum clarity - r = random.Random() num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape @@ -276,13 +274,16 @@ def _compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): - for gpu_id in range(num_gpus): - candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id, logical_expert_id - ) - logical_to_rank_dispatch_physical_map[ - gpu_id, layer_id, logical_expert_id - ] = r.choice(candidate_values) + TODO + + # TODO old + # for gpu_id in range(num_gpus): + # candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( + # logical_to_all_physical_map, layer_id, logical_expert_id + # ) + # logical_to_rank_dispatch_physical_map[ + # gpu_id, layer_id, logical_expert_id + # ] = r.choice(candidate_values) return logical_to_rank_dispatch_physical_map From 6ca89d09d0f6b85e9846da76d782996af3108e52 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:31:21 +0800 Subject: [PATCH 0801/1089] more --- python/sglang/srt/managers/expert_location.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index eb4765a482e..9625c3bc3b6 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -267,8 +267,9 @@ def _compute_logical_to_rank_dispatch_physical_map( r = random.Random() num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape - logical_to_rank_dispatch_physical_map = torch.zeros( - (num_gpus, num_layers, num_logical_experts), + logical_to_rank_dispatch_physical_map = torch.full( + size=(num_gpus, num_layers, num_logical_experts), + fill_value=-1, dtype=logical_to_all_physical_map.dtype, ) From 16bb3356ba8962683633a4df79def05ae919f878 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:31:51 +0800 Subject: [PATCH 0802/1089] more --- python/sglang/srt/managers/expert_location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9625c3bc3b6..579a0be41d6 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -275,6 +275,7 @@ def _compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): + partial_map = logical_to_all_physical_map[:, layer_id, logical_expert_id] TODO # TODO old @@ -286,6 +287,7 @@ def _compute_logical_to_rank_dispatch_physical_map( # gpu_id, layer_id, logical_expert_id # ] = r.choice(candidate_values) + assert torch.all(logical_to_rank_dispatch_physical_map != -1) return logical_to_rank_dispatch_physical_map From 1b755ea2d2c231b0946cce2b64c881023b669f95 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:32:23 +0800 Subject: [PATCH 0803/1089] more --- 
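The next few patches (0804-0807) fill in this loop: for every (layer, logical expert) pair, each GPU prefers a physical replica it hosts locally, and the remaining GPUs get a fair random spread over all replicas. One subtlety worth flagging: for the `-1` fill and the final assert to ever see the writes, `partial_map` has to be a view into the output tensor `logical_to_rank_dispatch_physical_map`, not into `logical_to_all_physical_map` as the hunks below slice it. A condensed sketch under that assumption (`r` is the `random.Random()` already in scope; `num_local_physical_experts = num_physical_experts // num_gpus` and `_fair_choices` arrive in PATCH 0806-0809):

    for layer_id in range(num_layers):
        for logical_expert_id in range(num_logical_experts):
            candidates = ExpertLocationMetadata.logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            # View into the output tensor, so the in-place writes are kept.
            partial_map = logical_to_rank_dispatch_physical_map[:, layer_id, logical_expert_id]
            for gpu_id in range(num_gpus):
                local = [p for p in candidates if p // num_local_physical_experts == gpu_id]
                if local:
                    partial_map[gpu_id] = local[0]  # prefer a replica this GPU hosts
            num_remain = torch.sum(partial_map == -1).item()
            partial_map[partial_map == -1] = torch.tensor(
                _fair_choices(candidates, k=num_remain, r=r), dtype=partial_map.dtype
            )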
python/sglang/srt/managers/expert_location.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 579a0be41d6..32b725c947c 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -276,7 +276,10 @@ def _compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): partial_map = logical_to_all_physical_map[:, layer_id, logical_expert_id] - TODO + + for gpu_id in range(num_gpus): + if TODO: + partial_map[gpu_id] = TODO # TODO old # for gpu_id in range(num_gpus): @@ -291,6 +294,10 @@ def _compute_logical_to_rank_dispatch_physical_map( return logical_to_rank_dispatch_physical_map +def _compute_gpu_id_of_physical_expert(physical_expert_id: int) -> int: + return TODO + + @dataclass class ModelConfigForExpertLocation: num_layers: int From 22ca7eaad99d849407ba7c2dfa4e0dc902476b59 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:33:08 +0800 Subject: [PATCH 0804/1089] more --- python/sglang/srt/managers/expert_location.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 32b725c947c..e92d008a437 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -275,17 +275,22 @@ def _compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): + candidate_physical_expert_ids = ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id + ) partial_map = logical_to_all_physical_map[:, layer_id, logical_expert_id] for gpu_id in range(num_gpus): + same_gpu_physical_expert_ids = [ + physical_expert_id + for physical_expert_id in candidate_physical_expert_ids + if _compute_gpu_id_of_physical_expert(physical_expert_id) == gpu_id + ] if TODO: partial_map[gpu_id] = TODO # TODO old # for gpu_id in range(num_gpus): - # candidate_values = ExpertLocationMetadata.logical_to_all_physical_raw( - # logical_to_all_physical_map, layer_id, logical_expert_id - # ) # logical_to_rank_dispatch_physical_map[ # gpu_id, layer_id, logical_expert_id # ] = r.choice(candidate_values) From a0be77b44ca8f5e61a09b00b933b83d91eab506e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:33:20 +0800 Subject: [PATCH 0805/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e92d008a437..78a728b7418 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -286,8 +286,8 @@ def _compute_logical_to_rank_dispatch_physical_map( for physical_expert_id in candidate_physical_expert_ids if _compute_gpu_id_of_physical_expert(physical_expert_id) == gpu_id ] - if TODO: - partial_map[gpu_id] = TODO + if len(same_gpu_physical_expert_ids) > 0: + partial_map[gpu_id] = same_gpu_physical_expert_ids[0] # TODO old # for gpu_id in range(num_gpus): From 54adc088cba3a5ab26f6afe0c70d2c83629edd4e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:34:16 +0800 Subject: [PATCH 0806/1089] more --- python/sglang/srt/managers/expert_location.py | 10 
+++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 78a728b7418..9c86d2dad3a 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -160,6 +160,7 @@ def _init_raw( logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map, num_gpus=ep_size, + num_physical_experts=num_physical_experts, ), ) @@ -262,11 +263,14 @@ def _pad_nested_array(arr, pad_value): def _compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, + num_physical_experts: int, ): # This is rarely called, so we use for loops for maximum clarity r = random.Random() + num_local_physical_experts = num_physical_experts // num_gpus num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + logical_to_rank_dispatch_physical_map = torch.full( size=(num_gpus, num_layers, num_logical_experts), fill_value=-1, @@ -284,7 +288,7 @@ def _compute_logical_to_rank_dispatch_physical_map( same_gpu_physical_expert_ids = [ physical_expert_id for physical_expert_id in candidate_physical_expert_ids - if _compute_gpu_id_of_physical_expert(physical_expert_id) == gpu_id + if _compute_gpu_id_of_physical_expert(physical_expert_id, num_local_physical_experts) == gpu_id ] if len(same_gpu_physical_expert_ids) > 0: partial_map[gpu_id] = same_gpu_physical_expert_ids[0] @@ -299,8 +303,8 @@ def _compute_logical_to_rank_dispatch_physical_map( return logical_to_rank_dispatch_physical_map -def _compute_gpu_id_of_physical_expert(physical_expert_id: int) -> int: - return TODO +def _compute_gpu_id_of_physical_expert(physical_expert_id: int, num_local_physical_experts: int) -> int: + return physical_expert_id // num_local_physical_experts @dataclass From c2f808a10289a6ebd730b5f502e3cfbab3a115af Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:35:01 +0800 Subject: [PATCH 0807/1089] more --- python/sglang/srt/managers/expert_location.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 9c86d2dad3a..0111c6994af 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -293,11 +293,9 @@ def _compute_logical_to_rank_dispatch_physical_map( if len(same_gpu_physical_expert_ids) > 0: partial_map[gpu_id] = same_gpu_physical_expert_ids[0] - # TODO old - # for gpu_id in range(num_gpus): - # logical_to_rank_dispatch_physical_map[ - # gpu_id, layer_id, logical_expert_id - # ] = r.choice(candidate_values) + num_remain = torch.sum(partial_map == -1).item() + partial_map[partial_map == -1] = torch.tensor( + _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r)) assert torch.all(logical_to_rank_dispatch_physical_map != -1) return logical_to_rank_dispatch_physical_map From 038b83e6a8d4f3d70afde8918d7fc5dbb8d063ef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:35:26 +0800 Subject: [PATCH 0808/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 0111c6994af..b2e29dc6551 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -305,6 +305,10 @@ def 
_compute_gpu_id_of_physical_expert(physical_expert_id: int, num_local_physic return physical_expert_id // num_local_physical_experts +def _fair_choices(arr: List, k: int, r: random.Random) -> List: + return TODO + + @dataclass class ModelConfigForExpertLocation: num_layers: int From db44e12370ebd15fc6781b48869a33028feafe29 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:35:51 +0800 Subject: [PATCH 0809/1089] more --- python/sglang/srt/managers/expert_location.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index b2e29dc6551..585c5c112e0 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -306,7 +306,10 @@ def _compute_gpu_id_of_physical_expert(physical_expert_id: int, num_local_physic def _fair_choices(arr: List, k: int, r: random.Random) -> List: - return TODO + quotient, remainder = divmod(k, len(arr)) + ans = arr * quotient + r.sample(arr, k=remainder) + r.shuffle(ans) + return ans @dataclass From 8f81f8272c1e3cc89a43817293300e8facbeeabb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:36:25 +0800 Subject: [PATCH 0810/1089] more --- python/sglang/srt/managers/expert_location.py | 4 ++-- test/srt/test_eplb.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 585c5c112e0..092545c4802 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -157,7 +157,7 @@ def _init_raw( physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, - logical_to_rank_dispatch_physical_map=_compute_logical_to_rank_dispatch_physical_map( + logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map, num_gpus=ep_size, num_physical_experts=num_physical_experts, @@ -260,7 +260,7 @@ def _pad_nested_array(arr, pad_value): return padded -def _compute_logical_to_rank_dispatch_physical_map( +def compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, num_physical_experts: int, diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index cfe3ef8c320..5a9d70f2194 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,9 +3,8 @@ import unittest from pathlib import Path -import torch - import sglang as sgl +import torch from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -144,11 +143,11 @@ def test_nontrivial_location(self): offset = 3 physical_to_logical_map = ( - offset - + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( - _NUM_HIDDEN_LAYERS, 1 - ) - ) % _NUM_ROUTED_EXPERTS + offset + + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( + _NUM_HIDDEN_LAYERS, 1 + ) + ) % _NUM_ROUTED_EXPERTS init_expert_location = dict( physical_to_logical_map=physical_to_logical_map.tolist() ) @@ -219,6 +218,9 @@ def _engine_flush_cache(self, engine: sgl.Engine): ret = engine.flush_cache() assert ret.success + def test_compute_logical_to_rank_dispatch_physical_map(self): + TODO + def 
_compute_trivial_expert_locations(ep_num_redundant_experts: int): return list( From 4ea5025c3550407f5ddff517495215e2f5195846 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:37:13 +0800 Subject: [PATCH 0811/1089] more --- test/srt/test_eplb.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 5a9d70f2194..6d8feb94362 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -5,6 +5,7 @@ import sglang as sgl import torch +from python.sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -143,8 +144,8 @@ def test_nontrivial_location(self): offset = 3 physical_to_logical_map = ( - offset - + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( + offset + + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( _NUM_HIDDEN_LAYERS, 1 ) ) % _NUM_ROUTED_EXPERTS @@ -219,7 +220,14 @@ def _engine_flush_cache(self, engine: sgl.Engine): assert ret.success def test_compute_logical_to_rank_dispatch_physical_map(self): - TODO + expect = torch.tensor([[]]) # TODO + actual = compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map=logical_to_all_physical_map, + num_gpus=TODO, + num_physical_experts=TODO, + ) + print(f"{actual=} {expect=}") + self.assertEqual(actual.tolist(), expect.tolist()) def _compute_trivial_expert_locations(ep_num_redundant_experts: int): From 231a7e4ac73a6d6c86a8b94ed8c8ad0f5739e2a1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:38:13 +0800 Subject: [PATCH 0812/1089] more --- test/srt/test_eplb.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 6d8feb94362..df16c250570 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -220,11 +220,16 @@ def _engine_flush_cache(self, engine: sgl.Engine): assert ret.success def test_compute_logical_to_rank_dispatch_physical_map(self): + # 8 logical expert + logical_to_all_physical_map = torch.tensor([ + [[0], [1], [2], [3], [4], [5], [6], [7]], + ]) expect = torch.tensor([[]]) # TODO + actual = compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map=logical_to_all_physical_map, - num_gpus=TODO, - num_physical_experts=TODO, + num_gpus=4, + num_physical_experts=12, ) print(f"{actual=} {expect=}") self.assertEqual(actual.tolist(), expect.tolist()) From 16ac1a1de974723006b8e6db47e3bff2b6dba900 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:38:58 +0800 Subject: [PATCH 0813/1089] more --- test/srt/test_eplb.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index df16c250570..6a78c6d7df6 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -221,18 +221,19 @@ def _engine_flush_cache(self, engine: sgl.Engine): def test_compute_logical_to_rank_dispatch_physical_map(self): # 8 logical expert - logical_to_all_physical_map = torch.tensor([ - [[0], [1], [2], [3], [4], [5], [6], [7]], - ]) - expect = torch.tensor([[]]) # TODO - - actual = compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map=logical_to_all_physical_map, - num_gpus=4, - num_physical_experts=12, - ) - print(f"{actual=} {expect=}") - 
self.assertEqual(actual.tolist(), expect.tolist()) + for logical_to_all_physical_map, expect_output in [ + ( + torch.tensor([[[0], [1], [2], [3], [4], [5], [6], [7]], ]), + torch.tensor([[]]), + ), + ]: + actual_output = compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map=logical_to_all_physical_map, + num_gpus=4, + num_physical_experts=12, + ) + print(f"{actual_output=} {expect_output=}") + self.assertEqual(actual_output.tolist(), expect_output.tolist()) def _compute_trivial_expert_locations(ep_num_redundant_experts: int): From e9cab0d9cfc57e334a1837e50771d4b420073179 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:39:15 +0800 Subject: [PATCH 0814/1089] more --- test/srt/test_eplb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 6a78c6d7df6..e09a119fad2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -223,17 +223,17 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): # 8 logical expert for logical_to_all_physical_map, expect_output in [ ( - torch.tensor([[[0], [1], [2], [3], [4], [5], [6], [7]], ]), - torch.tensor([[]]), + [[[0], [1], [2], [3], [4], [5], [6], [7]]], + [[]], ), ]: actual_output = compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map), num_gpus=4, num_physical_experts=12, ) print(f"{actual_output=} {expect_output=}") - self.assertEqual(actual_output.tolist(), expect_output.tolist()) + self.assertEqual(actual_output.tolist(), expect_output) def _compute_trivial_expert_locations(ep_num_redundant_experts: int): From a1f9601be19cb4c338296f17d1f5a8ad4c9ba539 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:40:37 +0800 Subject: [PATCH 0815/1089] more --- test/srt/test_eplb.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e09a119fad2..677014bb4e5 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -222,9 +222,15 @@ def _engine_flush_cache(self, engine: sgl.Engine): def test_compute_logical_to_rank_dispatch_physical_map(self): # 8 logical expert for logical_to_all_physical_map, expect_output in [ + # Identity map ( [[[0], [1], [2], [3], [4], [5], [6], [7]]], - [[]], + [[]], # TODO + ), + # Identity map + consider redundant experts + ( + [[[0, 8], [1, 9], [2, 10], [3, 11], [4], [5], [6], [7]]], + [[]], # TODO ), ]: actual_output = compute_logical_to_rank_dispatch_physical_map( From 66c2dcaddbf764b67d8609911e1dc2b4d0ad2030 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:41:55 +0800 Subject: [PATCH 0816/1089] more --- test/srt/test_eplb.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 677014bb4e5..3c31d515e45 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -232,6 +232,16 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): [[[0, 8], [1, 9], [2, 10], [3, 11], [4], [5], [6], [7]]], [[]], # TODO ), + # One logical expert is put on ALL gpus + ( + [[[0, 3, 6, 9], [1], [2], [4], [5], [7], [8], [10]]], + [[]], # TODO + ), + # One logical expert is put multiple times on ONE gpu + ( + [[[0, 1, 2], [3], [4], [5], [6], [7], [8], [9]]], + [[]], # TODO + ), ]: actual_output = compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map), From 
74f6b5bfbc2f979446ed390e9f63b25dc8d2ad0c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:42:32 +0800 Subject: [PATCH 0817/1089] more --- test/srt/test_eplb.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 3c31d515e45..90e856c297c 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -229,17 +229,19 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): ), # Identity map + consider redundant experts ( - [[[0, 8], [1, 9], [2, 10], [3, 11], [4], [5], [6], [7]]], + [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], [[]], # TODO ), # One logical expert is put on ALL gpus ( - [[[0, 3, 6, 9], [1], [2], [4], [5], [7], [8], [10]]], + [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], + [8, -1, -1, -1], [10, -1, -1, -1]]], [[]], # TODO ), # One logical expert is put multiple times on ONE gpu ( - [[[0, 1, 2], [3], [4], [5], [6], [7], [8], [9]]], + [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], + [9, -1, -1]]], [[]], # TODO ), ]: From b55b38e5f31701361de7b416fba479a3da1d1ef1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:43:39 +0800 Subject: [PATCH 0818/1089] more --- test/srt/test_eplb.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 90e856c297c..58427a6a2cb 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -244,6 +244,12 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): [9, -1, -1]]], [[]], # TODO ), + # Random + ( + [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], + [7, -1, -1]]], + [[]], # TODO + ), ]: actual_output = compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map), From cd6c57b49e8eb99fc6f6575c0f7d7a57a6314d35 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:44:19 +0800 Subject: [PATCH 0819/1089] more --- test/srt/test_eplb.py | 40 ------------------------- test/srt/test_expert_location.py | 51 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 40 deletions(-) create mode 100644 test/srt/test_expert_location.py diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 58427a6a2cb..3ea0a7bd1b9 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -219,46 +219,6 @@ def _engine_flush_cache(self, engine: sgl.Engine): ret = engine.flush_cache() assert ret.success - def test_compute_logical_to_rank_dispatch_physical_map(self): - # 8 logical expert - for logical_to_all_physical_map, expect_output in [ - # Identity map - ( - [[[0], [1], [2], [3], [4], [5], [6], [7]]], - [[]], # TODO - ), - # Identity map + consider redundant experts - ( - [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], - [[]], # TODO - ), - # One logical expert is put on ALL gpus - ( - [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], - [8, -1, -1, -1], [10, -1, -1, -1]]], - [[]], # TODO - ), - # One logical expert is put multiple times on ONE gpu - ( - [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], - [9, -1, -1]]], - [[]], # TODO - ), - # Random - ( - [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], - [7, -1, -1]]], - [[]], # TODO - ), - ]: - actual_output = 
compute_logical_to_rank_dispatch_physical_map(
-                logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map),
-                num_gpus=4,
-                num_physical_experts=12,
-            )
-            print(f"{actual_output=} {expect_output=}")
-            self.assertEqual(actual_output.tolist(), expect_output)
 
 
 def _compute_trivial_expert_locations(ep_num_redundant_experts: int):
     return list(
diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py
new file mode 100644
index 00000000000..1078b507aa3
--- /dev/null
+++ b/test/srt/test_expert_location.py
@@ -0,0 +1,51 @@
+import unittest
+
+import torch
+from sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map
+from sglang.test.test_utils import CustomTestCase
+
+
+class TestExpertLocation(CustomTestCase):
+    def test_compute_logical_to_rank_dispatch_physical_map(self):
+        # 8 logical experts
+        for logical_to_all_physical_map, expect_output in [
+            # Identity map
+            (
+                [[[0], [1], [2], [3], [4], [5], [6], [7]]],
+                [[]],  # TODO
+            ),
+            # Identity map + consider redundant experts
+            (
+                [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]],
+                [[]],  # TODO
+            ),
+            # One logical expert is put on ALL gpus
+            (
+                [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1],
+                  [8, -1, -1, -1], [10, -1, -1, -1]]],
+                [[]],  # TODO
+            ),
+            # One logical expert is put multiple times on ONE gpu
+            (
+                [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1],
+                  [9, -1, -1]]],
+                [[]],  # TODO
+            ),
+            # Random
+            (
+                [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1],
+                  [7, -1, -1]]],
+                [[]],  # TODO
+            ),
+        ]:
+            actual_output = compute_logical_to_rank_dispatch_physical_map(
+                logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map),
+                num_gpus=4,
+                num_physical_experts=12,
+            )
+            print(f"{actual_output=} {expect_output=}")
+            self.assertEqual(actual_output.tolist(), expect_output)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 823d6d03a596429fef58bde5437a8421218de4de Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:44:36 +0800
Subject: [PATCH 0820/1089] more

---
 test/srt/run_suite.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3d0417b201d..ca2f4632525 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -31,6 +31,7 @@ class TestFile:
         TestFile("test_fa3.py", 5),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
+        TestFile("test_expert_location.py", 10),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
@@ -195,7 +196,7 @@ def auto_partition(files, rank, size):
     if args.auto_partition_size:
         files = auto_partition(files, args.auto_partition_id, args.auto_partition_size)
     else:
-        files = files[args.range_begin : args.range_end]
+        files = files[args.range_begin: args.range_end]
 
 print("The running tests are ", [f.name for f in files])

From 4fe1e1abc88414265a6715872b37829b1593b4b1 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 09:46:13 +0800
Subject: [PATCH 0821/1089] more

---
 python/sglang/srt/managers/eplb_simulator.py              | 1 +
 python/sglang/srt/managers/expert_distribution_storage.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py
index 1664dedc9f5..4244a3fb433 100644
--- 
a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -1,4 +1,5 @@ # TODO where to put this file? +# TODO add some doc import dataclasses import json from collections import defaultdict diff --git a/python/sglang/srt/managers/expert_distribution_storage.py b/python/sglang/srt/managers/expert_distribution_storage.py index ebc51daec7f..8a52ada46b0 100644 --- a/python/sglang/srt/managers/expert_distribution_storage.py +++ b/python/sglang/srt/managers/expert_distribution_storage.py @@ -28,6 +28,7 @@ async def save_current(self): logger.info(f"save_current to path {path}") path.write_text(json.dumps(data)) + # Most vanilla method since I do not have production environment data to test what algorithm is better def get_last_snapshot(self) -> Optional[Dict[str, Any]]: path = self.get_last_snapshot_path(self._dir_data) if path is None: From 0df2d6b961d7e96f47ed05a0a328fd73f89b7d0b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:46:22 +0800 Subject: [PATCH 0822/1089] more --- .../srt/managers/expert_distribution.py | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cc8f404d071..24fd8572c36 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -23,9 +23,9 @@ class ExpertDistributionRecorder: @staticmethod def init_new( - server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", - rank: int, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, ): if server_args.enable_expert_distribution_recorder: return _ExpertDistributionRecorderReal(server_args, expert_location_metadata, rank) @@ -73,10 +73,10 @@ class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder): class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): def __init__( - self, - server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", - rank: int, + self, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, ): self._server_args = server_args self._expert_location_metadata = expert_location_metadata @@ -145,7 +145,7 @@ def _reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") assert ( - self._current_layer_idx.value is None + self._current_layer_idx.value is None ), f"{self._current_layer_idx.value=}" for gatherer in self._single_pass_gatherers.values(): gatherer.reset() @@ -192,7 +192,7 @@ def set_global_expert_distribution_recorder(value): def postprocess_dumps( - physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" + physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" ): return _Accumulator.get_class().postprocess_dumps( physical_dumps, expert_location_metadata @@ -205,7 +205,7 @@ def postprocess_dumps( class _SinglePassGatherer(ABC): @staticmethod def init_new( - server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata" + server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata" ) -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # `auto` has many restrictions now, so we lower the priority to implement low-latency capturing for auto @@ -224,7 +224,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass def on_deepep_dispatch_normal( - self, layer_idx: 
int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] ): pass @@ -244,7 +244,7 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): self._num_recv_tokens_per_expert_list_of_layer = {} def _on_layer_data( - self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] ): # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer @@ -288,7 +288,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): def on_deepep_dispatch_normal( - self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] ): assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) @@ -311,7 +311,7 @@ def reset(self): self._data[...] = 0 def collect(self) -> torch.Tensor: - return self._data.clone() + return self._data # --------------------------------------- Accumulator ----------------------------------------- @@ -322,7 +322,7 @@ def collect(self) -> torch.Tensor: class _Accumulator(ABC): @staticmethod def init_new( - expert_location_metadata: "ExpertLocationMetadata", rank: int + expert_location_metadata: "ExpertLocationMetadata", rank: int ) -> "_Accumulator": return _Accumulator.get_class()(expert_location_metadata, rank) @@ -344,17 +344,17 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): @classmethod def postprocess_dumps( - cls, - physical_dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): raise NotImplementedError def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_physical_count: torch.Tensor, ): raise NotImplementedError @@ -368,9 +368,9 @@ def dump(self): class _DetailAccumulator(_Accumulator): @classmethod def postprocess_dumps( - cls, - physical_dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): # Do not convert to logical since we want all details return [record for physical_dump in physical_dumps for record in physical_dump] @@ -394,10 +394,10 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return super().get_single_pass_gatherer_key(debug_name) def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_physical_count: torch.Tensor, ): single_pass_physical_count = single_pass_physical_count.to("cpu") if self._save_dir is None: @@ -428,9 +428,9 @@ def dump(self): class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps( - cls, - physical_dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + physical_dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): logical_count = torch.zeros( ( @@ -442,7 +442,7 @@ def postprocess_dumps( for physical_dump in physical_dumps: for layer_index in range(expert_location_metadata.num_layers): for local_physical_expert_index in range( - 
expert_location_metadata.num_local_physical_experts + expert_location_metadata.num_local_physical_experts ): global_physical_expert_index = ( expert_location_metadata.local_physical_to_physical( @@ -470,10 +470,10 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int ) def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_physical_count: torch.Tensor, ): self._physical_count += single_pass_physical_count From 895768aa394558209dc8a2066883bd5096bfb108 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:46:45 +0800 Subject: [PATCH 0823/1089] more --- python/sglang/srt/managers/expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 24fd8572c36..16978bdd8ad 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -475,6 +475,7 @@ def append( gatherer_key: str, single_pass_physical_count: torch.Tensor, ): + # Can optimize if overhead here is large self._physical_count += single_pass_physical_count def reset(self): From dcde89af84c841db83f8faa14611863a4a7172dd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:47:26 +0800 Subject: [PATCH 0824/1089] more --- python/sglang/srt/entrypoints/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a472509866a..83920237681 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -659,7 +659,6 @@ def _compute_initial_expert_location_metadata( data_dict = json.loads(Path(data).read_text()) if "physical_to_logical_map" in data_dict: - # TODO We may want to allow users to not provide `logical_to_all_physical_map` if this API is frequently used return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) elif "logical_count" in data_dict: return ExpertLocationMetadata.init_by_eplb(server_args, **data_dict) From b5c96e4856b2ee93e50dc352d0a9eff086079570 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:47:28 +0800 Subject: [PATCH 0825/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index f55c6fe7372..6668d96b950 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -75,6 +75,7 @@ def _execute_unblock_req(self): def _compute_global_unblock_barrier(self): local_arrived = self._noop or (self._state == _State.GLOBAL_UNBLOCK_BARRIER) global_arrived = torch.tensor(local_arrived).cuda() + # Can optimize if bottleneck torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) return global_arrived.cpu().item() From a0d95f5bd45d1c46ad77239a311637ff4809a09a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:48:20 +0800 Subject: [PATCH 0826/1089] more --- python/sglang/srt/entrypoints/engine.py | 166 ++++++++++++------------ 1 file changed, 85 insertions(+), 81 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 83920237681..5a839412cf1 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ 
b/python/sglang/srt/entrypoints/engine.py @@ -137,33 +137,33 @@ def __init__(self, **kwargs): ) def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be an image instance, file name, URL, or base64 encoded string. - # Can be formatted as: - # - Single image for a single request - # - List of images (one per request in a batch) - # - List of lists of images (multiple images per request) - # See also python/sglang/srt/utils.py:load_image for more details. - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - return_hidden_states: bool = False, - stream: bool = False, + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, + stream: bool = False, ) -> Union[Dict, Iterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. @@ -202,32 +202,32 @@ def generator_wrapper(): return ret async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - # The image input. It can be an image instance, file name, URL, or base64 encoded string. - # Can be formatted as: - # - Single image for a single request - # - List of images (one per request in a batch) - # - List of lists of images (multiple images per request) - # See also python/sglang/srt/utils.py:load_image for more details. 
- image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - stream: bool = False, + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be an image instance, file name, URL, or base64 encoded string. + # Can be formatted as: + # - Single image for a single request + # - List of images (one per request in a batch) + # - List of lists of images (multiple images per request) + # See also python/sglang/srt/utils.py:load_image for more details. + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, ) -> Union[Dict, AsyncIterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. @@ -254,15 +254,15 @@ async def async_generate( return await generator.__anext__() def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - image_data: Optional[ - Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], - ] - ] = None, + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + image_data: Optional[ + Union[ + List[List[Union[Image, str]]], + List[Union[Image, str]], + Union[Image, str], + ] + ] = None, ) -> Dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. @@ -310,13 +310,13 @@ def get_server_info(self): } def init_weights_update_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", ): """Initialize parameter update group.""" obj = InitWeightsUpdateGroupReqInput( @@ -345,10 +345,10 @@ def update_weights_from_distributed(self, name: str, dtype, shape): ) def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, torch.Tensor]], - load_format: Optional[str] = None, - flush_cache: bool = True, + self, + named_tensors: List[Tuple[str, torch.Tensor]], + load_format: Optional[str] = None, + flush_cache: bool = True, ): """Update weights from distributed source. 
If there are going to be more updates, set `flush_cache` to be true to avoid duplicated operations such as clearing cache.""" @@ -385,10 +385,10 @@ def update_expert_location(self, expert_location_metadata: ExpertLocationMetadat ) def update_weights_from_disk( - self, - model_path: str, - load_format: Optional[str] = None, - param_categories: Optional[List[str]] = None, + self, + model_path: str, + load_format: Optional[str] = None, + param_categories: Optional[List[str]] = None, ): """Update the weights from disk inplace without re-launching the engine. @@ -507,7 +507,7 @@ def sigquit_handler(signum, frame): def _launch_subprocesses( - server_args: ServerArgs, port_args: Optional[PortArgs] = None + server_args: ServerArgs, port_args: Optional[PortArgs] = None ) -> Tuple[TokenizerManager, Dict]: """ Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. @@ -551,8 +551,8 @@ def _launch_subprocesses( for tp_rank in tp_rank_range: reader, writer = mp.Pipe(duplex=False) gpu_id = ( - server_args.base_gpu_id - + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + server_args.base_gpu_id + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step ) proc = mp.Process( target=run_scheduler_process, @@ -650,7 +650,7 @@ def _launch_subprocesses( def _compute_initial_expert_location_metadata( - server_args: ServerArgs, eplb_manager: EPLBManager + server_args: ServerArgs, eplb_manager: EPLBManager ) -> ExpertLocationMetadata: if (data := server_args.init_expert_location) is not None: try: @@ -659,13 +659,17 @@ def _compute_initial_expert_location_metadata( data_dict = json.loads(Path(data).read_text()) if "physical_to_logical_map" in data_dict: + logger.info("init_expert_location from init_by_mapping using ServerArgs.init_expert_location") return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) elif "logical_count" in data_dict: + logger.info("init_expert_location from init_by_eplb using ServerArgs.init_expert_location") return ExpertLocationMetadata.init_by_eplb(server_args, **data_dict) else: raise NotImplementedError( f"Unknown init_expert_location format ({list(data_dict.keys())=})" ) if server_args.enable_eplb: + logger.info("init_expert_location from EPLBManager") return eplb_manager.compute_expert_location_metadata() + return ExpertLocationMetadata.init_trivial(server_args) From 7295b5002669cd7022852b23e3683c650dac684d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:58:22 +0800 Subject: [PATCH 0827/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 +++ python/sglang/srt/model_executor/model_runner.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 16978bdd8ad..f8fc78a2416 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -119,6 +119,9 @@ def _on_forward_pass_end(self, forward_pass_id: int): forward_pass_id, gatherer_key, single_pass_physical_count ) + def flush_buffer_depending_on_expert_location_metadata(self): + TODO + def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f99e32d3242..292c2092b4c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ 
-496,6 +496,8 @@ def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location start") torch.distributed.barrier() + get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() + get_global_expert_location_metadata().update(recv_req.expert_location_metadata) if self.tp_rank == 0 and get_bool_env_var( "SGLANG_LOG_EXPERT_LOCATION_METADATA" From e1e260dcbc7341421552a21996814b79380da446 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:59:23 +0800 Subject: [PATCH 0828/1089] more --- python/sglang/srt/managers/expert_distribution.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f8fc78a2416..02a4df47402 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -120,7 +120,7 @@ def _on_forward_pass_end(self, forward_pass_id: int): ) def flush_buffer_depending_on_expert_location_metadata(self): - TODO + self._accumulator.flush_buffer_depending_on_expert_location_metadata() def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) @@ -367,6 +367,9 @@ def reset(self): def dump(self): raise NotImplementedError + def flush_buffer_depending_on_expert_location_metadata(self): + raise NotImplementedError + class _DetailAccumulator(_Accumulator): @classmethod @@ -427,6 +430,9 @@ def dump(self): torch.save(self._records, str(path_output)) return [dict(path_output=str(path_output))] + def flush_buffer_depending_on_expert_location_metadata(self): + pass + class _StatAccumulator(_Accumulator): @classmethod @@ -489,3 +495,6 @@ def dump(self): rank=self._rank, physical_count=self._physical_count.tolist(), ) + + def flush_buffer_depending_on_expert_location_metadata(self): + TODO From 51d3ec4a5b14279d666218a0702ab5dd2bd89283 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 09:59:52 +0800 Subject: [PATCH 0829/1089] more --- .../sglang/srt/managers/expert_distribution.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 02a4df47402..7d5b8ec2206 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -114,9 +114,9 @@ def _on_forward_pass_end(self, forward_pass_id: int): if not self._recording: return for gatherer_key, gatherer in self._single_pass_gatherers.items(): - single_pass_physical_count = gatherer.collect() + single_pass_global_physical_count = gatherer.collect() self._accumulator.append( - forward_pass_id, gatherer_key, single_pass_physical_count + forward_pass_id, gatherer_key, single_pass_global_physical_count ) def flush_buffer_depending_on_expert_location_metadata(self): @@ -357,7 +357,7 @@ def append( self, forward_pass_id: int, gatherer_key: str, - single_pass_physical_count: torch.Tensor, + single_pass_global_physical_count: torch.Tensor, ): raise NotImplementedError @@ -403,18 +403,18 @@ def append( self, forward_pass_id: int, gatherer_key: str, - single_pass_physical_count: torch.Tensor, + single_pass_global_physical_count: torch.Tensor, ): - single_pass_physical_count = single_pass_physical_count.to("cpu") + single_pass_global_physical_count = single_pass_global_physical_count.to("cpu") if self._save_dir is None: - 
single_pass_physical_count = single_pass_physical_count.tolist() + single_pass_global_physical_count = single_pass_global_physical_count.tolist() self._records.append( dict( forward_pass_id=forward_pass_id, rank=self._rank, gatherer_key=gatherer_key, - physical_count=single_pass_physical_count, + physical_count=single_pass_global_physical_count, ) ) @@ -482,10 +482,10 @@ def append( self, forward_pass_id: int, gatherer_key: str, - single_pass_physical_count: torch.Tensor, + single_pass_global_physical_count: torch.Tensor, ): # Can optimize if overhead here is large - self._physical_count += single_pass_physical_count + self._physical_count += single_pass_global_physical_count def reset(self): self._physical_count[...] = 0 From a437a35c5cc8dbb0c2374b333adf6cf4062fcd64 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:00:47 +0800 Subject: [PATCH 0830/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7d5b8ec2206..7ba27567cba 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -114,7 +114,7 @@ def _on_forward_pass_end(self, forward_pass_id: int): if not self._recording: return for gatherer_key, gatherer in self._single_pass_gatherers.items(): - single_pass_global_physical_count = gatherer.collect() + single_pass_global_physical_count = gatherer.collect_global_physical_count() self._accumulator.append( forward_pass_id, gatherer_key, single_pass_global_physical_count ) @@ -237,7 +237,7 @@ def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tenso def reset(self): raise NotImplementedError - def collect(self) -> torch.Tensor: + def collect_global_physical_count(self) -> torch.Tensor: raise NotImplementedError @@ -259,7 +259,7 @@ def _on_layer_data( def reset(self): self._num_recv_tokens_per_expert_list_of_layer.clear() - def collect(self) -> torch.Tensor: + def collect_global_physical_count(self) -> torch.Tensor: data = [ self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) or ([0] * self._expert_location_metadata.num_local_physical_experts) @@ -313,7 +313,7 @@ def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tenso def reset(self): self._data[...] 
= 0 - def collect(self) -> torch.Tensor: + def collect_global_physical_count(self) -> torch.Tensor: return self._data From ed4e00366a60edc6ec27b4696b0ab3bceacd00bb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:01:10 +0800 Subject: [PATCH 0831/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7ba27567cba..10773a755d6 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -244,24 +244,24 @@ def collect_global_physical_count(self) -> torch.Tensor: class _LayerBasedSinglePassGatherer(_SinglePassGatherer): def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): super().__init__(expert_location_metadata) - self._num_recv_tokens_per_expert_list_of_layer = {} + self._objects_of_layer = {} def _on_layer_data( self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] ): # TODO for TBO, we may need to relax this restriction - assert layer_idx not in self._num_recv_tokens_per_expert_list_of_layer + assert layer_idx not in self._objects_of_layer assert 0 <= layer_idx < self._expert_location_metadata.num_layers - self._num_recv_tokens_per_expert_list_of_layer[layer_idx] = ( + self._objects_of_layer[layer_idx] = ( num_recv_tokens_per_expert_list ) def reset(self): - self._num_recv_tokens_per_expert_list_of_layer.clear() + self._objects_of_layer.clear() def collect_global_physical_count(self) -> torch.Tensor: data = [ - self._num_recv_tokens_per_expert_list_of_layer.get(layer_index) + self._objects_of_layer.get(layer_index) or ([0] * self._expert_location_metadata.num_local_physical_experts) for layer_index in range(self._expert_location_metadata.num_layers) ] From 0e9994fd2cf7d3528f3b50cfb3c27dff0ee6760a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:01:29 +0800 Subject: [PATCH 0832/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 10773a755d6..f5f70a3bcb5 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -247,14 +247,11 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): self._objects_of_layer = {} def _on_layer_data( - self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, objects: List[int] ): - # TODO for TBO, we may need to relax this restriction assert layer_idx not in self._objects_of_layer assert 0 <= layer_idx < self._expert_location_metadata.num_layers - self._objects_of_layer[layer_idx] = ( - num_recv_tokens_per_expert_list - ) + self._objects_of_layer[layer_idx] = objects def reset(self): self._objects_of_layer.clear() From baa46d17c8cc2e1fb369b0fd4daf2bab69d3fe5f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:01:54 +0800 Subject: [PATCH 0833/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f5f70a3bcb5..2a58c1466c6 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -256,10 
+256,10 @@ def _on_layer_data( def reset(self): self._objects_of_layer.clear() - def collect_global_physical_count(self) -> torch.Tensor: + def _collect_objects(self, pad_len: int) -> torch.Tensor: data = [ self._objects_of_layer.get(layer_index) - or ([0] * self._expert_location_metadata.num_local_physical_experts) + or ([0] * pad_len) for layer_index in range(self._expert_location_metadata.num_layers) ] return torch.tensor(data) From ea35d22bcc6674b44f5d53db2a7e3206204435f4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:02:07 +0800 Subject: [PATCH 0834/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2a58c1466c6..e69d10923a7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -285,6 +285,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) + def collect_global_physical_count(self) -> torch.Tensor: + return TODO + class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): def on_deepep_dispatch_normal( @@ -293,6 +296,9 @@ def on_deepep_dispatch_normal( assert isinstance(num_recv_tokens_per_expert_list, list) self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) + def collect_global_physical_count(self) -> torch.Tensor: + return TODO + class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): From e4a93cc05f0f922bf47cde74f87978ffa45967c0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:02:28 +0800 Subject: [PATCH 0835/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 +------ python/sglang/srt/managers/expert_location.py | 3 --- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e69d10923a7..a3ce42ec1f7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -276,12 +276,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): ] * self._expert_location_metadata.num_local_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: - local_physical_expert_idx = ( - self._expert_location_metadata.physical_to_local_physical( - global_physical_expert_idx - ) - ) - num_recv_tokens_per_expert_list[local_physical_expert_idx] += 1 + num_recv_tokens_per_expert_list[global_physical_expert_idx] += 1 self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 092545c4802..75e76e2d5ee 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -194,9 +194,6 @@ def to(self, device): def local_physical_to_physical(self, rank: int, local_physical_expert_index: int): return self.num_local_physical_experts * rank + local_physical_expert_index - def physical_to_local_physical(self, global_physical_expert_index: int): - return global_physical_expert_index % self.num_local_physical_experts - def logical_to_all_physical( self, layer_id: int, logical_expert_id: int ) -> List[int]: From 121c96a433ae541353284163e7412dbab5c62726 Mon 
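For reference, the _collect_objects helper introduced above can be read as the following self-contained sketch: layers that never reported data are padded with an all-zero row, so the result is always a dense (num_layers, pad_len) tensor. Here objects_of_layer mirrors the dict the gatherer fills lazily:

    import torch

    def collect_objects(objects_of_layer: dict, num_layers: int, pad_len: int) -> torch.Tensor:
        # Missing layers contribute [0] * pad_len; present layers their recorded list.
        rows = [objects_of_layer.get(i) or [0] * pad_len for i in range(num_layers)]
        return torch.tensor(rows)

    # e.g. collect_objects({0: [3, 1]}, num_layers=2, pad_len=2) -> [[3, 1], [0, 0]]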
Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:02:49 +0800 Subject: [PATCH 0836/1089] more --- python/sglang/srt/managers/expert_distribution.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a3ce42ec1f7..63fd2367ac7 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -271,14 +271,12 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - num_recv_tokens_per_expert_list = [ - 0 - ] * self._expert_location_metadata.num_local_physical_experts + global_physical_count = [0] * self._expert_location_metadata.num_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: - num_recv_tokens_per_expert_list[global_physical_expert_idx] += 1 + global_physical_count[global_physical_expert_idx] += 1 - self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) + self._on_layer_data(layer_idx, global_physical_count) def collect_global_physical_count(self) -> torch.Tensor: return TODO From 436e8e89c44b09fc9c129688859347cfcfa91ceb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:03:09 +0800 Subject: [PATCH 0837/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 63fd2367ac7..ae2b6ab8856 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -279,7 +279,7 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._on_layer_data(layer_idx, global_physical_count) def collect_global_physical_count(self) -> torch.Tensor: - return TODO + return super()._collect_objects(pad_len=self._expert_location_metadata.num_physical_experts) class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): From a57b0ad1b7cc18cf02110ec7ec271cf9d4f22468 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:03:24 +0800 Subject: [PATCH 0838/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ae2b6ab8856..6179e817609 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -290,6 +290,8 @@ def on_deepep_dispatch_normal( self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) def collect_global_physical_count(self) -> torch.Tensor: + local_physical_count = super()._collect_objects( + pad_len=self._expert_location_metadata.num_local_physical_experts) return TODO From 8c0ce1928ea0b8ce5495fab5d667cc4c416809ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:03:51 +0800 Subject: [PATCH 0839/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6179e817609..9732536c344 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -292,7 +292,7 @@ def 
on_deepep_dispatch_normal( def collect_global_physical_count(self) -> torch.Tensor: local_physical_count = super()._collect_objects( pad_len=self._expert_location_metadata.num_local_physical_experts) - return TODO + return _convert_local_to_global_physical_count(local_physical_count) class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): From be2a76babd4c9a2651f8665f887f6082ebd95642 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:04:14 +0800 Subject: [PATCH 0840/1089] more --- python/sglang/srt/managers/expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9732536c344..d15b9fcd2ea 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -312,7 +312,8 @@ def reset(self): self._data[...] = 0 def collect_global_physical_count(self) -> torch.Tensor: - return self._data + # Can optimize if bottleneck + return _convert_local_to_global_physical_count(self._data) # --------------------------------------- Accumulator ----------------------------------------- From d99f56b80e17fc869e186382bf8aaec489847f0d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:05:16 +0800 Subject: [PATCH 0841/1089] more --- python/sglang/srt/managers/expert_distribution.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index d15b9fcd2ea..f744e82f365 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -316,6 +316,16 @@ def collect_global_physical_count(self) -> torch.Tensor: return _convert_local_to_global_physical_count(self._data) +def _convert_local_to_global_physical_count(local_physical_count: torch.Tensor) -> torch.Tensor: + dtype = local_physical_count.dtype + device = local_physical_count.device + num_layers, _ = local_physical_count.shape + + ans = torch.zeros((num_layers, TODO), dtype=dtype, device=device) + ans[:, TODO:TODO] = local_physical_count + return ans + + # --------------------------------------- Accumulator ----------------------------------------- _SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary" From f760bdb0b15b2550c0e2adb9a2b6b3688726d80a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:05:29 +0800 Subject: [PATCH 0842/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f744e82f365..54a70f9d220 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -322,7 +322,7 @@ def _convert_local_to_global_physical_count(local_physical_count: torch.Tensor) num_layers, _ = local_physical_count.shape ans = torch.zeros((num_layers, TODO), dtype=dtype, device=device) - ans[:, TODO:TODO] = local_physical_count + ans[:, num_local_physical_experts * rank:num_local_physical_experts * (rank + 1)] = local_physical_count return ans From 39bc97378cc81cd24d045bde00d03cd5c9981a2d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:06:18 +0800 Subject: [PATCH 0843/1089] more --- .../sglang/srt/managers/expert_distribution.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git 
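The _convert_local_to_global_physical_count helper being drafted here (its TODO holes are filled over the next few commits) converges on the following shape: each rank owns one contiguous slice of the physical experts, so its local counts are scattered into that slice of a zero-initialized global tensor. A runnable sketch of the converged form:

    import torch

    def local_to_global_physical_count(
        local_count: torch.Tensor,  # (num_layers, num_local_physical_experts)
        rank: int,
        num_local_physical_experts: int,
        num_physical_experts: int,
    ) -> torch.Tensor:
        num_layers, _ = local_count.shape
        out = torch.zeros(
            (num_layers, num_physical_experts),
            dtype=local_count.dtype,
            device=local_count.device,
        )
        start = num_local_physical_experts * rank
        out[:, start : start + num_local_physical_experts] = local_count
        return out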
a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 54a70f9d220..6cb50176c8a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -86,7 +86,7 @@ def __init__( self._current_debug_name = Withable() self._accumulator = _Accumulator.init_new(expert_location_metadata, rank) self._single_pass_gatherers = { - k: _SinglePassGatherer.init_new(server_args, expert_location_metadata) + k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank) for k in self._accumulator.get_single_pass_gatherer_keys() } @@ -208,20 +208,21 @@ def postprocess_dumps( class _SinglePassGatherer(ABC): @staticmethod def init_new( - server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata" + server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int, ) -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # `auto` has many restrictions now, so we lower the priority to implement low-latency capturing for auto if server_args.deepep_mode in ["normal", "auto"]: - return _DeepepNormalSinglePassGatherer(expert_location_metadata) + return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) elif server_args.deepep_mode == "low_latency": - return _DeepepLowLatencySinglePassGatherer(expert_location_metadata) + return _DeepepLowLatencySinglePassGatherer(expert_location_metadata, rank) else: raise NotImplementedError - return _SelectExpertsSinglePassGatherer(expert_location_metadata) + return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) - def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): + def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): self._expert_location_metadata = expert_location_metadata + self._rank = rank def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass @@ -242,8 +243,8 @@ def collect_global_physical_count(self) -> torch.Tensor: class _LayerBasedSinglePassGatherer(_SinglePassGatherer): - def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): - super().__init__(expert_location_metadata) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._objects_of_layer = {} def _on_layer_data( From ef61f217edff2bf29ef83e2607aa8321b784390c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:06:47 +0800 Subject: [PATCH 0844/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 6cb50176c8a..40d446fd62b 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -317,12 +317,17 @@ def collect_global_physical_count(self) -> torch.Tensor: return _convert_local_to_global_physical_count(self._data) -def _convert_local_to_global_physical_count(local_physical_count: torch.Tensor) -> torch.Tensor: +def _convert_local_to_global_physical_count( + local_physical_count: torch.Tensor, + rank: int, + num_local_physical_experts: int, + num_physical_experts: int, +) -> torch.Tensor: dtype = local_physical_count.dtype device = local_physical_count.device num_layers, _ = local_physical_count.shape - ans = torch.zeros((num_layers, TODO), dtype=dtype, device=device) + ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device) 
ans[:, num_local_physical_experts * rank:num_local_physical_experts * (rank + 1)] = local_physical_count return ans From 8c5926de62c2702b472a47ecf924f1004c7ad6df Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:07:14 +0800 Subject: [PATCH 0845/1089] more --- python/sglang/srt/managers/expert_distribution.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 40d446fd62b..edf659b35bd 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -293,7 +293,12 @@ def on_deepep_dispatch_normal( def collect_global_physical_count(self) -> torch.Tensor: local_physical_count = super()._collect_objects( pad_len=self._expert_location_metadata.num_local_physical_experts) - return _convert_local_to_global_physical_count(local_physical_count) + return _convert_local_to_global_physical_count( + local_physical_count, + rank=self._rank, + num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts, + num_physical_experts=self._expert_location_metadata.num_physical_experts, + ) class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): @@ -314,7 +319,12 @@ def reset(self): def collect_global_physical_count(self) -> torch.Tensor: # Can optimize if bottleneck - return _convert_local_to_global_physical_count(self._data) + return _convert_local_to_global_physical_count( + self._data, + rank=self._rank, + num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts, + num_physical_experts=self._expert_location_metadata.num_physical_experts, + ) def _convert_local_to_global_physical_count( From 4ff6c407ff7da41ff4f90cc179403860a42073ac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:07:51 +0800 Subject: [PATCH 0846/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index edf659b35bd..cbc18b5856c 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -302,10 +302,10 @@ def collect_global_physical_count(self) -> torch.Tensor: class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): - def __init__(self, expert_location_metadata: "ExpertLocationMetadata"): - super().__init__(expert_location_metadata) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._data = torch.zeros( - (expert_location_metadata.num_layers, expert_location_metadata.num_local_physical_experts), + (self._expert_location_metadata.num_layers, self._expert_location_metadata.num_local_physical_experts), dtype=torch.int, device="cuda", ) From 8560a4da7ab7fc07a7bee0345104c509d3b8578e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:08:41 +0800 Subject: [PATCH 0847/1089] more --- .../srt/managers/expert_distribution.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cbc18b5856c..44d13c80536 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -47,10 +47,10 @@ def with_forward_pass(self, forward_pass_id: int): def on_select_experts(self, topk_ids: torch.Tensor): 
pass - def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): + def on_deepep_dispatch_normal(self, local_physical_count_of_layer: List[int]): pass - def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor): + def on_deepep_dispatch_low_latency(self, local_physical_count_of_layer: torch.Tensor): pass def start_record(self): @@ -125,14 +125,14 @@ def flush_buffer_depending_on_expert_location_metadata(self): def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) - def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list: List[int]): + def on_deepep_dispatch_normal(self, local_physical_count_of_layer: List[int]): self._on_hook( "on_deepep_dispatch_normal", - num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, + local_physical_count_of_layer=local_physical_count_of_layer, ) - def on_deepep_dispatch_low_latency(self, recv_count: torch.Tensor): - self._on_hook("on_deepep_dispatch_low_latency", recv_count=recv_count) + def on_deepep_dispatch_low_latency(self, local_physical_count_of_layer: torch.Tensor): + self._on_hook("on_deepep_dispatch_low_latency", local_physical_count_of_layer=local_physical_count_of_layer) def _on_hook(self, hook_name: str, **kwargs): if not (self._recording or torch.cuda.is_current_stream_capturing()): @@ -228,11 +228,11 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass def on_deepep_dispatch_normal( - self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, local_physical_count_of_layer: List[int] ): pass - def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): + def on_deepep_dispatch_low_latency(self, layer_idx: int, local_physical_count_of_layer: torch.Tensor): pass def reset(self): @@ -285,10 +285,10 @@ def collect_global_physical_count(self) -> torch.Tensor: class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): def on_deepep_dispatch_normal( - self, layer_idx: int, num_recv_tokens_per_expert_list: List[int] + self, layer_idx: int, local_physical_count_of_layer: List[int] ): - assert isinstance(num_recv_tokens_per_expert_list, list) - self._on_layer_data(layer_idx, num_recv_tokens_per_expert_list) + assert isinstance(local_physical_count_of_layer, list) + self._on_layer_data(layer_idx, local_physical_count_of_layer) def collect_global_physical_count(self) -> torch.Tensor: local_physical_count = super()._collect_objects( @@ -310,9 +310,9 @@ def __init__(self, *args, **kwargs): device="cuda", ) - def on_deepep_dispatch_low_latency(self, layer_idx: int, recv_count: torch.Tensor): + def on_deepep_dispatch_low_latency(self, layer_idx: int, local_physical_count_of_layer: torch.Tensor): # Most naive implementation, can optimize later - self._data[layer_idx, :] = recv_count + self._data[layer_idx, :] = local_physical_count_of_layer def reset(self): self._data[...] 
= 0 From 6505e2bfa3a05a0ab491b885987de37c35ab1fd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:09:17 +0800 Subject: [PATCH 0848/1089] more --- python/sglang/srt/managers/expert_distribution.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 44d13c80536..e237f6a316a 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -466,6 +466,7 @@ def postprocess_dumps( physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): + TODO logical_count = torch.zeros( ( expert_location_metadata.num_layers, @@ -496,6 +497,7 @@ def postprocess_dumps( def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) + TODO self._physical_count = torch.zeros( ( self._expert_location_metadata.num_layers, @@ -509,13 +511,16 @@ def append( gatherer_key: str, single_pass_global_physical_count: torch.Tensor, ): + TODO # Can optimize if overhead here is large self._physical_count += single_pass_global_physical_count def reset(self): + TODO self._physical_count[...] = 0 def dump(self): + TODO return dict( rank=self._rank, physical_count=self._physical_count.tolist(), From 3d3b9f59fb6d2bfe8a9e31a01536075324737560 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:09:43 +0800 Subject: [PATCH 0849/1089] more --- python/sglang/srt/managers/expert_distribution.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index e237f6a316a..2726ffb1931 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -497,13 +497,18 @@ def postprocess_dumps( def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) - TODO - self._physical_count = torch.zeros( + self._buffer_global_physical_count = torch.zeros( ( self._expert_location_metadata.num_layers, self._expert_location_metadata.num_local_physical_experts, ) ) + self._logical_count = torch.zeros( + ( + self._expert_location_metadata.num_layers, + self._expert_location_metadata.num_logical_experts, + ) + ) def append( self, From 09dee3d20b05ca2109caba98f852f1a9f177a5cf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:10:03 +0800 Subject: [PATCH 0850/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 2726ffb1931..b032375bb71 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -516,13 +516,12 @@ def append( gatherer_key: str, single_pass_global_physical_count: torch.Tensor, ): - TODO # Can optimize if overhead here is large - self._physical_count += single_pass_global_physical_count + self._buffer_global_physical_count += single_pass_global_physical_count def reset(self): - TODO - self._physical_count[...] = 0 + self._buffer_global_physical_count[...] = 0 + self._logical_count[...] 
= 0 def dump(self): TODO From 0efd3a881853c58084df2845627e5b9b63799557 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:10:48 +0800 Subject: [PATCH 0851/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index b032375bb71..5b39b1fde64 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -490,6 +490,7 @@ def postprocess_dumps( layer_index, global_physical_expert_index ] ) + TODO_no_physical_count_key logical_count[layer_index, logical_expert_index] += physical_dump[ "physical_count" ][layer_index][local_physical_expert_index] @@ -524,10 +525,11 @@ def reset(self): self._logical_count[...] = 0 def dump(self): - TODO + self.flush_buffer_depending_on_expert_location_metadata() + return dict( rank=self._rank, - physical_count=self._physical_count.tolist(), + logical_count=self._logical_count.tolist(), ) def flush_buffer_depending_on_expert_location_metadata(self): From 9b9f06eddbd4409b04c13dd37af315a2991ea93e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:12:17 +0800 Subject: [PATCH 0852/1089] more --- python/sglang/srt/managers/expert_distribution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 5b39b1fde64..cc3cf886208 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -466,7 +466,11 @@ def postprocess_dumps( physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): - TODO + TODO_rename_physical_dumps + logical_count = torch.stack([item["logical_count"] for item in physical_dumps]).sum(dim=0) + return dict(logical_count=logical_count.tolist()) + + TODO_remove_below logical_count = torch.zeros( ( expert_location_metadata.num_layers, From b546f0734950344fae2bbcb88d7776d44bb6d925 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:12:43 +0800 Subject: [PATCH 0853/1089] more --- .../srt/managers/expert_distribution.py | 62 ++++++++++--------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index cc3cf886208..8e4c0029186 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -470,36 +470,6 @@ def postprocess_dumps( logical_count = torch.stack([item["logical_count"] for item in physical_dumps]).sum(dim=0) return dict(logical_count=logical_count.tolist()) - TODO_remove_below - logical_count = torch.zeros( - ( - expert_location_metadata.num_layers, - expert_location_metadata.num_logical_experts, - ) - ) - # Most naive implementation, can optimize if it is bottleneck - for physical_dump in physical_dumps: - for layer_index in range(expert_location_metadata.num_layers): - for local_physical_expert_index in range( - expert_location_metadata.num_local_physical_experts - ): - global_physical_expert_index = ( - expert_location_metadata.local_physical_to_physical( - rank=physical_dump["rank"], - local_physical_expert_index=local_physical_expert_index, - ) - ) - logical_expert_index = ( - expert_location_metadata.physical_to_logical_map[ - layer_index, global_physical_expert_index - ] - ) - 
TODO_no_physical_count_key - logical_count[layer_index, logical_expert_index] += physical_dump[ - "physical_count" - ][layer_index][local_physical_expert_index] - return dict(logical_count=logical_count.tolist()) - def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) self._buffer_global_physical_count = torch.zeros( @@ -538,3 +508,35 @@ def dump(self): def flush_buffer_depending_on_expert_location_metadata(self): TODO + + +TODO +""" + logical_count = torch.zeros( + ( + expert_location_metadata.num_layers, + expert_location_metadata.num_logical_experts, + ) + ) + # Most naive implementation, can optimize if it is bottleneck + for physical_dump in physical_dumps: + for layer_index in range(expert_location_metadata.num_layers): + for local_physical_expert_index in range( + expert_location_metadata.num_local_physical_experts + ): + global_physical_expert_index = ( + expert_location_metadata.local_physical_to_physical( + rank=physical_dump["rank"], + local_physical_expert_index=local_physical_expert_index, + ) + ) + logical_expert_index = ( + expert_location_metadata.physical_to_logical_map[ + layer_index, global_physical_expert_index + ] + ) + TODO_no_physical_count_key + logical_count[layer_index, logical_expert_index] += physical_dump[ + "physical_count" + ][layer_index][local_physical_expert_index] +""" From 394bd526777c30e575fa4fd1fc71231633cecce9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:13:22 +0800 Subject: [PATCH 0854/1089] more --- python/sglang/srt/managers/expert_distribution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 8e4c0029186..901565ddc7e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -507,7 +507,12 @@ def dump(self): ) def flush_buffer_depending_on_expert_location_metadata(self): - TODO + self._logical_count += _convert_global_physical_count_to_logical_count(self._buffer_global_physical_count) + self._buffer_global_physical_count[...] 
= 0 + + +def _convert_global_physical_count_to_logical_count(global_physical_count: torch.Tensor): + return TODO TODO From f2238895f4adf2c4d5e57977a58fdd22c932025e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:14:04 +0800 Subject: [PATCH 0855/1089] more --- .../srt/managers/expert_distribution.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 901565ddc7e..680095a6382 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -195,10 +195,10 @@ def set_global_expert_distribution_recorder(value): def postprocess_dumps( - physical_dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" + dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" ): return _Accumulator.get_class().postprocess_dumps( - physical_dumps, expert_location_metadata + dumps, expert_location_metadata ) @@ -373,7 +373,7 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): @classmethod def postprocess_dumps( cls, - physical_dumps: List[Any], + dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): raise NotImplementedError @@ -400,11 +400,11 @@ class _DetailAccumulator(_Accumulator): @classmethod def postprocess_dumps( cls, - physical_dumps: List[Any], + dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): # Do not convert to logical since we want all details - return [record for physical_dump in physical_dumps for record in physical_dump] + return [record for dump in dumps for record in dump] def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): super().__init__(expert_location_metadata, rank) @@ -463,11 +463,10 @@ class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps( cls, - physical_dumps: List[Any], + dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): - TODO_rename_physical_dumps - logical_count = torch.stack([item["logical_count"] for item in physical_dumps]).sum(dim=0) + logical_count = torch.stack([item["logical_count"] for item in dumps]).sum(dim=0) return dict(logical_count=logical_count.tolist()) def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): @@ -524,14 +523,14 @@ def _convert_global_physical_count_to_logical_count(global_physical_count: torch ) ) # Most naive implementation, can optimize if it is bottleneck - for physical_dump in physical_dumps: + for dump in dumps: for layer_index in range(expert_location_metadata.num_layers): for local_physical_expert_index in range( expert_location_metadata.num_local_physical_experts ): global_physical_expert_index = ( expert_location_metadata.local_physical_to_physical( - rank=physical_dump["rank"], + rank=dump["rank"], local_physical_expert_index=local_physical_expert_index, ) ) @@ -541,7 +540,7 @@ def _convert_global_physical_count_to_logical_count(global_physical_count: torch ] ) TODO_no_physical_count_key - logical_count[layer_index, logical_expert_index] += physical_dump[ + logical_count[layer_index, logical_expert_index] += dump[ "physical_count" ][layer_index][local_physical_expert_index] """ From 65947e3f65fc6d9b746be081bb598adae171d2c6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:14:19 +0800 Subject: [PATCH 0856/1089] more --- .../srt/managers/expert_distribution.py | 57 +++++++++---------- 1 file changed, 26 insertions(+), 31 
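The physical-to-logical conversion being shaped here is kept deliberately naive ("can optimize if it is bottleneck"). If it ever does become a bottleneck, one vectorized equivalent (an assumption on my part, not part of the patch) replaces the per-layer, per-expert loop with a single scatter_add_:

    import torch

    def to_logical_count(
        global_physical_count: torch.Tensor,    # (num_layers, num_physical_experts)
        physical_to_logical_map: torch.Tensor,  # same shape, integer entries
        num_logical_experts: int,
    ) -> torch.Tensor:
        num_layers = global_physical_count.shape[0]
        logical = torch.zeros(
            (num_layers, num_logical_experts), dtype=global_physical_count.dtype
        )
        # logical[l, map[l, p]] += count[l, p] for every layer l, physical expert p
        logical.scatter_add_(1, physical_to_logical_map.long(), global_physical_count)
        return logical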
deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 680095a6382..ea86bcb1963 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -511,36 +511,31 @@ def flush_buffer_depending_on_expert_location_metadata(self): def _convert_global_physical_count_to_logical_count(global_physical_count: torch.Tensor): - return TODO - - -TODO -""" - logical_count = torch.zeros( - ( - expert_location_metadata.num_layers, - expert_location_metadata.num_logical_experts, - ) + logical_count = torch.zeros( + ( + expert_location_metadata.num_layers, + expert_location_metadata.num_logical_experts, ) - # Most naive implementation, can optimize if it is bottleneck - for dump in dumps: - for layer_index in range(expert_location_metadata.num_layers): - for local_physical_expert_index in range( - expert_location_metadata.num_local_physical_experts - ): - global_physical_expert_index = ( - expert_location_metadata.local_physical_to_physical( - rank=dump["rank"], - local_physical_expert_index=local_physical_expert_index, - ) - ) - logical_expert_index = ( - expert_location_metadata.physical_to_logical_map[ - layer_index, global_physical_expert_index - ] + ) + # Most naive implementation, can optimize if it is bottleneck + for dump in dumps: + for layer_index in range(expert_location_metadata.num_layers): + for local_physical_expert_index in range( + expert_location_metadata.num_local_physical_experts + ): + global_physical_expert_index = ( + expert_location_metadata.local_physical_to_physical( + rank=dump["rank"], + local_physical_expert_index=local_physical_expert_index, ) - TODO_no_physical_count_key - logical_count[layer_index, logical_expert_index] += dump[ - "physical_count" - ][layer_index][local_physical_expert_index] -""" + ) + logical_expert_index = ( + expert_location_metadata.physical_to_logical_map[ + layer_index, global_physical_expert_index + ] + ) + TODO_no_physical_count_key + logical_count[layer_index, logical_expert_index] += dump[ + "physical_count" + ][layer_index][local_physical_expert_index] + return logical_count From 0ce12ff26bc9bba25e43bf2f1c34c79ae11b06ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:14:48 +0800 Subject: [PATCH 0857/1089] more --- .../srt/managers/expert_distribution.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index ea86bcb1963..9b67fa8c069 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -510,7 +510,10 @@ def flush_buffer_depending_on_expert_location_metadata(self): self._buffer_global_physical_count[...] 
= 0 -def _convert_global_physical_count_to_logical_count(global_physical_count: torch.Tensor): +def _convert_global_physical_count_to_logical_count( + global_physical_count: torch.Tensor, + expert_location_metadata: ExpertLocationMetadata, +): logical_count = torch.zeros( ( expert_location_metadata.num_layers, @@ -518,24 +521,23 @@ def _convert_global_physical_count_to_logical_count(global_physical_count: torch ) ) # Most naive implementation, can optimize if it is bottleneck - for dump in dumps: - for layer_index in range(expert_location_metadata.num_layers): - for local_physical_expert_index in range( - expert_location_metadata.num_local_physical_experts - ): - global_physical_expert_index = ( - expert_location_metadata.local_physical_to_physical( - rank=dump["rank"], - local_physical_expert_index=local_physical_expert_index, - ) - ) - logical_expert_index = ( - expert_location_metadata.physical_to_logical_map[ - layer_index, global_physical_expert_index - ] + for layer_index in range(expert_location_metadata.num_layers): + for local_physical_expert_index in range( + expert_location_metadata.num_local_physical_experts + ): + global_physical_expert_index = ( + expert_location_metadata.local_physical_to_physical( + rank=dump["rank"], + local_physical_expert_index=local_physical_expert_index, ) - TODO_no_physical_count_key - logical_count[layer_index, logical_expert_index] += dump[ - "physical_count" - ][layer_index][local_physical_expert_index] + ) + logical_expert_index = ( + expert_location_metadata.physical_to_logical_map[ + layer_index, global_physical_expert_index + ] + ) + TODO_no_physical_count_key + logical_count[layer_index, logical_expert_index] += dump[ + "physical_count" + ][layer_index][local_physical_expert_index] return logical_count From af4a7cc80bf43c01b4a5b1e287fc7a068a153cbc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:15:23 +0800 Subject: [PATCH 0858/1089] more --- .../sglang/srt/managers/expert_distribution.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 9b67fa8c069..104033733a5 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -522,22 +522,12 @@ def _convert_global_physical_count_to_logical_count( ) # Most naive implementation, can optimize if it is bottleneck for layer_index in range(expert_location_metadata.num_layers): - for local_physical_expert_index in range( - expert_location_metadata.num_local_physical_experts - ): - global_physical_expert_index = ( - expert_location_metadata.local_physical_to_physical( - rank=dump["rank"], - local_physical_expert_index=local_physical_expert_index, - ) - ) + for global_physical_expert_index in range(expert_location_metadata.num_physical_experts): logical_expert_index = ( expert_location_metadata.physical_to_logical_map[ layer_index, global_physical_expert_index ] ) - TODO_no_physical_count_key - logical_count[layer_index, logical_expert_index] += dump[ - "physical_count" - ][layer_index][local_physical_expert_index] + logical_count[layer_index, logical_expert_index] += global_physical_count[layer_index][ + local_physical_expert_index] return logical_count From 8fa7258ad736b81cf99dee5a8f36809bab8ba8ad Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:15:56 +0800 Subject: [PATCH 0859/1089] more --- python/sglang/srt/managers/expert_distribution.py | 4 ++-- 1 file changed, 2 
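Once gatherers emit zero-padded global tensors (see _convert_local_to_global_physical_count above), per-rank results combine by elementwise sum rather than by concatenation along the expert axis: summing the padded tensors reproduces exactly what torch.cat used to build from the local ones. A small check of that identity, with toy shapes:

    import torch

    def pad_to_global(local: torch.Tensor, rank: int, n_local: int, n_global: int) -> torch.Tensor:
        out = torch.zeros(local.shape[0], n_global, dtype=local.dtype)
        out[:, rank * n_local : (rank + 1) * n_local] = local
        return out

    per_rank_local = [torch.tensor([[1, 2]]), torch.tensor([[3, 4]])]
    padded = [pad_to_global(t, r, n_local=2, n_global=4) for r, t in enumerate(per_rank_local)]
    assert torch.equal(torch.stack(padded).sum(dim=0), torch.cat(per_rank_local, dim=-1))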
insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 104033733a5..7da177999f3 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -528,6 +528,6 @@ def _convert_global_physical_count_to_logical_count( layer_index, global_physical_expert_index ] ) - logical_count[layer_index, logical_expert_index] += global_physical_count[layer_index][ - local_physical_expert_index] + logical_count[layer_index, logical_expert_index] += global_physical_count[ + layer_index, global_physical_expert_index] return logical_count From c8f0ecad30b94f21759e00bd94fc64f761abfd06 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:16:54 +0800 Subject: [PATCH 0860/1089] more --- python/sglang/srt/managers/eplb_simulator.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index 4244a3fb433..d225bb9802b 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -86,13 +86,8 @@ def read_physical_count_of_forward_pass(dir_data: Path): for forward_pass_id, physical_count_of_rank in sorted( physical_count_of_forward_pass_id_and_rank.items() ): - physical_count_of_rank_tensor = torch.cat( - [ - physical_count - for rank, physical_count in sorted(physical_count_of_rank.items()) - ], - dim=-1, - ) + physical_count_of_rank_tensor = torch.stack( + [physical_count for rank, physical_count in sorted(physical_count_of_rank.items())]).sum(dim=0) items.append(physical_count_of_rank_tensor) physical_count_of_forward_pass = torch.stack(items) From 53f143545a9a71f091f6aa816f4bbd970ef7c9d2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:20:32 +0800 Subject: [PATCH 0861/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 python/sglang/srt/model_executor/expert_location_updater.py diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py new file mode 100644 index 00000000000..51b3b4e0f42 --- /dev/null +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -0,0 +1,2 @@ +class ExpertLocationUpdater: + pass From 19633b21aca11b263ed46d2b41e0d0b23028266a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:21:42 +0800 Subject: [PATCH 0862/1089] more --- .../model_executor/expert_location_updater.py | 42 ++++++++++++++++++- .../sglang/srt/model_executor/model_runner.py | 23 +--------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 51b3b4e0f42..6e152e8c887 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -1,2 +1,42 @@ +import logging +from typing import TYPE_CHECKING + +import torch +from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput +from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata +from sglang.srt.utils import get_bool_env_var + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = 
logging.getLogger(__name__) + + class ExpertLocationUpdater: - pass + def __init__(self, model_runner: "ModelRunner"): + self._model_runner = model_runner + + def act(self, recv_req: UpdateExpertLocationReqInput): + logger.info("update_expert_location start") + torch.distributed.barrier() + + get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() + + get_global_expert_location_metadata().update(recv_req.expert_location_metadata) + if self._model_runner.tp_rank == 0 and get_bool_env_var( + "SGLANG_LOG_EXPERT_LOCATION_METADATA" + ): + logger.info( + f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" + ) + + # We may be able to further reduce lock time by faster copying, pre-transfering, etc + self._model_runner.update_weights_from_disk( + model_path=self._model_runner.model_config.model_path, + load_format=self._model_runner.server_args.load_format, + param_categories=["moe"], + ) + + logger.info("update_expert_location end") + torch.distributed.barrier() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 292c2092b4c..221367187b6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -493,28 +493,7 @@ def load_model(self): ) from None def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): - logger.info("update_expert_location start") - torch.distributed.barrier() - - get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() - - get_global_expert_location_metadata().update(recv_req.expert_location_metadata) - if self.tp_rank == 0 and get_bool_env_var( - "SGLANG_LOG_EXPERT_LOCATION_METADATA" - ): - logger.info( - f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" - ) - - # We may be able to further reduce lock time by faster copying, pre-transfering, etc - self.update_weights_from_disk( - model_path=self.model_config.model_path, - load_format=self.server_args.load_format, - param_categories=["moe"], - ) - - logger.info("update_expert_location end") - torch.distributed.barrier() + TODO def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] From 46ebc09116a6b92a8be466caa863107288393f79 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:22:18 +0800 Subject: [PATCH 0863/1089] more --- python/sglang/srt/model_executor/model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 221367187b6..d8af8b32688 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -62,6 +62,7 @@ ) from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import ( @@ -206,6 +207,8 @@ def __init__( # If it is a draft model tp_group can be different. 
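
A side note on the count-conversion patches above (0858/0859): the remaining per-layer Python loop in _convert_global_physical_count_to_logical_count is flagged as "most naive" in its own comment, and it can be collapsed into a single scatter_add_ call, since physical_to_logical_map appears to share the (num_layers, num_physical_experts) shape of global_physical_count. A minimal sketch, not the series' own code; it assumes num_logical_experts is passed in explicitly rather than read from expert_location_metadata:

    import torch

    def convert_global_physical_count_to_logical_count_vectorized(
        global_physical_count: torch.Tensor,    # (num_layers, num_physical_experts)
        physical_to_logical_map: torch.Tensor,  # (num_layers, num_physical_experts)
        num_logical_experts: int,
    ) -> torch.Tensor:
        num_layers, _ = global_physical_count.shape
        logical_count = torch.zeros(
            (num_layers, num_logical_experts),
            dtype=global_physical_count.dtype,
            device=global_physical_count.device,
        )
        # One kernel instead of nested Python loops. For every layer l and
        # physical expert p this accumulates:
        #   logical_count[l, physical_to_logical_map[l, p]] += global_physical_count[l, p]
        logical_count.scatter_add_(
            dim=1,
            index=physical_to_logical_map.long(),
            src=global_physical_count,
        )
        return logical_count
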
self.initialize(min_per_gpu_memory) + self._expert_location_updater = ExpertLocationUpdater(self) + def initialize(self, min_per_gpu_memory: float): server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( @@ -493,7 +496,7 @@ def load_model(self): ) from None def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): - TODO + self._expert_location_updater.act(recv_req) def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] From 647b0acb498254d260862bc590a7b9081d17d947 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:39:54 +0800 Subject: [PATCH 0864/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 378c0b93479..1a3359f6910 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -707,7 +707,10 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): assert ( self.server_args.enable_scheduler_input_blocker and (self.server_args.ep_dispatch_algorithm is not None) ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm" + + TODO + async def _update_expert_location_raw(self, obj: UpdateExpertLocationReqInput): self.expert_location_metadata = None self._send_block_request(BlockReqType.BLOCK) From 16582dd92defadabe24b9886fca50cf05daec611 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:40:32 +0800 Subject: [PATCH 0865/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1a3359f6910..6227e9a217e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -707,8 +707,12 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): assert ( self.server_args.enable_scheduler_input_blocker and (self.server_args.ep_dispatch_algorithm is not None) ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm" - - TODO + + old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) + for TODO in TODO: + await self._update_expert_location_raw(UpdateExpertLocationReqInput( + expert_location_metadata=TODO, + )) async def _update_expert_location_raw(self, obj: UpdateExpertLocationReqInput): self.expert_location_metadata = None From 34d9e141511830bbb4d52a305d782707cace63d0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:41:12 +0800 Subject: [PATCH 0866/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6227e9a217e..638f25e3939 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -709,9 +709,11 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm" old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) - for TODO in TODO: + for what in TODO: + 
partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) + partial_expert_location_metadata.update(obj.expert_location_metadata, layer_ids=TODO) await self._update_expert_location_raw(UpdateExpertLocationReqInput( - expert_location_metadata=TODO, + expert_location_metadata=partial_expert_location_metadata, )) async def _update_expert_location_raw(self, obj: UpdateExpertLocationReqInput): From 9666260b757b3eeabe7eec0c47a2c79bd7f54f90 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:41:59 +0800 Subject: [PATCH 0867/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 75e76e2d5ee..a5a6e836cd1 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -166,7 +166,8 @@ def _init_raw( # -------------------------------- mutation ------------------------------------ - def update(self, other: "ExpertLocationMetadata"): + def update(self, other: "ExpertLocationMetadata", layer_id_start: Optional[int] = None, + layer_id_len: Optional[int] = None): for field in [ "ep_size", ]: From 8118d2b1a7149a7cb8a4fdb197f479b41335aab6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:42:39 +0800 Subject: [PATCH 0868/1089] more --- python/sglang/srt/managers/expert_location.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index a5a6e836cd1..97d45a8115d 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -173,11 +173,11 @@ def update(self, other: "ExpertLocationMetadata", layer_id_start: Optional[int] ]: assert getattr(self, field) == getattr(other, field) - for field in [ - "physical_to_logical_map", - "logical_to_all_physical_map", - "logical_to_all_physical_map_num_valid", - "logical_to_rank_dispatch_physical_map", + for field, layer_id_dim in [ + ("physical_to_logical_map", 0), + ("logical_to_all_physical_map", 0), + ("logical_to_all_physical_map_num_valid", 0), + ("logical_to_rank_dispatch_physical_map", 1), ]: # Cannot update address to avoid breaking CUDA graph getattr(self, field)[...] = getattr(other, field) From 8bc37cff5d42f1ecc6002016ad6ef259266c4ba1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:43:33 +0800 Subject: [PATCH 0869/1089] more --- python/sglang/srt/managers/expert_location.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 97d45a8115d..3a981961664 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -179,6 +179,12 @@ def update(self, other: "ExpertLocationMetadata", layer_id_start: Optional[int] ("logical_to_all_physical_map_num_valid", 0), ("logical_to_rank_dispatch_physical_map", 1), ]: + def _get(obj): + ans = getattr(obj, field) + if (layer_id_start is not None) or (layer_id_len is not None): + ans = ans.narrow(dim=layer_id_dim, start=layer_id_start, length=layer_id_len) + return ans + # Cannot update address to avoid breaking CUDA graph getattr(self, field)[...] 
= getattr(other, field) From 59d25f837e1e9e925fbda0d2a169046b2332dda6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:43:52 +0800 Subject: [PATCH 0870/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3a981961664..b5d03f11d69 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -186,7 +186,8 @@ def _get(obj): return ans # Cannot update address to avoid breaking CUDA graph - getattr(self, field)[...] = getattr(other, field) + dst = _get(self) + dst[...] = _get(other) def to(self, device): for field in [ From afcdca8952f6c7ccfe7c44b449096b51e8a1c0fb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:45:01 +0800 Subject: [PATCH 0871/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 638f25e3939..42ff7d0fb76 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -711,7 +711,7 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) for what in TODO: partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) - partial_expert_location_metadata.update(obj.expert_location_metadata, layer_ids=TODO) + partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, layer_id_len=TODO) await self._update_expert_location_raw(UpdateExpertLocationReqInput( expert_location_metadata=partial_expert_location_metadata, )) From eb5cb9f574cf899e548b1fb34854e8bf73baa4a2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:45:32 +0800 Subject: [PATCH 0872/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 42ff7d0fb76..6de26008535 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -719,10 +719,12 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): async def _update_expert_location_raw(self, obj: UpdateExpertLocationReqInput): self.expert_location_metadata = None + TODO_prepare + + TODO_rename_to_act self._send_block_request(BlockReqType.BLOCK) await self.update_expert_location_communicator.call_send(obj) self._send_block_request(BlockReqType.UNBLOCK) - await self.update_expert_location_communicator.call_await() self.expert_location_metadata = obj.expert_location_metadata From 286b4f82bff39582eb292dee038bcc86d720c9ab Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 10:46:54 +0800 Subject: [PATCH 0873/1089] more --- .../sglang/srt/managers/tokenizer_manager.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6de26008535..bd47edfb338 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -709,25 +709,30 @@ async def update_expert_location(self, obj: 
UpdateExpertLocationReqInput): ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm" old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) - for what in TODO: + num_layers = old_expert_location_metadata.num_layers + + layer_id_lens = list(range(0, num_layers, 10)) + [num_layers] + + for layer_id_len in layer_id_lens: partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) - partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, layer_id_len=TODO) - await self._update_expert_location_raw(UpdateExpertLocationReqInput( + partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, + layer_id_len=layer_id_len) + await self._update_expert_location_raw( expert_location_metadata=partial_expert_location_metadata, - )) + ) - async def _update_expert_location_raw(self, obj: UpdateExpertLocationReqInput): + async def _update_expert_location_raw(self, expert_location_metadata: ExpertLocationMetadata): self.expert_location_metadata = None TODO_prepare TODO_rename_to_act self._send_block_request(BlockReqType.BLOCK) - await self.update_expert_location_communicator.call_send(obj) + await self.update_expert_location_communicator.call_send(TODO) self._send_block_request(BlockReqType.UNBLOCK) await self.update_expert_location_communicator.call_await() - self.expert_location_metadata = obj.expert_location_metadata + self.expert_location_metadata = expert_location_metadata async def update_weights_from_disk( self, From a32010f8f9bae151937a0d7a672082b949cfec59 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:02:03 +0800 Subject: [PATCH 0874/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 6e152e8c887..6d55b2db8eb 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -17,6 +17,9 @@ class ExpertLocationUpdater: def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner + def prepare(self): + TODO + def act(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location start") torch.distributed.barrier() From fbbd6cc9d38856772818f705a043f388ec832b7c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:02:26 +0800 Subject: [PATCH 0875/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 python/sglang/srt/model_executor/model_weight_updater.py diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py new file mode 100644 index 00000000000..b22118c4ed8 --- /dev/null +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -0,0 +1,2 @@ +class ModelWeightUpdater: + TODO From 1a4f05b758cf37fdf546f2eebdedd59bd1896337 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:03:28 +0800 Subject: [PATCH 0876/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 python/sglang/srt/model_executor/memory_transfer.py diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py new file mode 100644 index 
00000000000..f0d9abf0d9f --- /dev/null +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -0,0 +1,13 @@ +from typing import List, Tuple + +import torch + +NamedTensors = List[Tuple[str, torch.Tensor]] + + +class TensorOperationManagerBase: + def enqueue(self, named_tensors: NamedTensors): + raise NotImplementedError + + def get_outputs(self) -> List[NamedTensors]: + raise NotImplementedError From b53ade6cd636abc3074c2c2aadcf7a48f2f4442e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:03:56 +0800 Subject: [PATCH 0877/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index f0d9abf0d9f..ca02886ca02 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -11,3 +11,11 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: raise NotImplementedError + + +class AsyncPinMemoryManager(TensorOperationManagerBase): + TODO + + +class AsyncToCudaManager(TensorOperationManagerBase): + TODO From 840dedd7a013c44015313be87aefa36d9dedaf7d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:04:20 +0800 Subject: [PATCH 0878/1089] more --- .../srt/model_executor/memory_transfer.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index ca02886ca02..99481c55abd 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -13,9 +13,25 @@ def get_outputs(self) -> List[NamedTensors]: raise NotImplementedError +class CombinedManager(TensorOperationManagerBase): + def enqueue(self, named_tensors: NamedTensors): + TODO + + def get_outputs(self) -> List[NamedTensors]: + return TODO + + class AsyncPinMemoryManager(TensorOperationManagerBase): - TODO + def enqueue(self, named_tensors: NamedTensors): + TODO + + def get_outputs(self) -> List[NamedTensors]: + return TODO class AsyncToCudaManager(TensorOperationManagerBase): - TODO + def enqueue(self, named_tensors: NamedTensors): + TODO + + def get_outputs(self) -> List[NamedTensors]: + return TODO From c450db49321e13469a15901dea1279f7569c974e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:04:44 +0800 Subject: [PATCH 0879/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 99481c55abd..d17dabcf224 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -14,6 +14,10 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): + @classmethod + def init_pin_memory_and_to_cuda(cls): + return cls(manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager()) + def enqueue(self, named_tensors: NamedTensors): TODO From 053a54792810330d3e095d77d5b319db27940f7c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:14:33 +0800 Subject: [PATCH 0880/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index b22118c4ed8..0e5abe9f625 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,2 +1,4 @@ class ModelWeightUpdater: - TODO + def __init__(self): + self._manager_transfer_manager = TODO + self._model_weight_source = TODO From 687647c6054deb187f0e46e2d12e1f700bb1fab5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:15:22 +0800 Subject: [PATCH 0881/1089] more --- .../model_executor/model_weight_updater.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 0e5abe9f625..8503402b15d 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,4 +1,19 @@ +from abc import ABC + + class ModelWeightUpdater: - def __init__(self): + def __init__(self, init_pin_memory: bool): self._manager_transfer_manager = TODO - self._model_weight_source = TODO + self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() + + +class _ModelWeightSourceBase(ABC): + pass + + +class _ModelWeightSourceVanilla(ABC): + TODO + + +class _ModelWeightSourcePinnedMemory(ABC): + TODO From c40ed8740b89d605b56e47bfb4e94711eac339ed Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:15:49 +0800 Subject: [PATCH 0882/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 8503402b15d..3b4e21a30f2 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,4 +1,7 @@ from abc import ABC +from typing import Iterator, Tuple + +import torch class ModelWeightUpdater: @@ -8,7 +11,8 @@ def __init__(self, init_pin_memory: bool): class _ModelWeightSourceBase(ABC): - pass + def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + raise NotImplementedError class _ModelWeightSourceVanilla(ABC): From 31c15053205f7bd082015e92cce90f23b1aa3bb2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:16:11 +0800 Subject: [PATCH 0883/1089] more --- .../sglang/srt/model_executor/model_weight_updater.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 3b4e21a30f2..5b9583b8f77 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -15,9 +15,11 @@ def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: raise NotImplementedError -class _ModelWeightSourceVanilla(ABC): - TODO +class _ModelWeightSourceVanilla(_ModelWeightSourceBase): + def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield TODO -class _ModelWeightSourcePinnedMemory(ABC): - TODO +class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): + def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield TODO From e7669d0cb9b65fcf844267bbbaca0a7794067b0f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 
11:16:48 +0800 Subject: [PATCH 0884/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 5b9583b8f77..b53559ef3da 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -2,11 +2,12 @@ from typing import Iterator, Tuple import torch +from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager class ModelWeightUpdater: def __init__(self, init_pin_memory: bool): - self._manager_transfer_manager = TODO + self._manager_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() From 4c767b3a327bab5cabb8ee8ab18ec5b0f0689038 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:17:23 +0800 Subject: [PATCH 0885/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 +- python/sglang/srt/model_executor/model_weight_updater.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 6d55b2db8eb..77878704903 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -17,7 +17,7 @@ class ExpertLocationUpdater: def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner - def prepare(self): + def start_prepare(self): TODO def act(self, recv_req: UpdateExpertLocationReqInput): diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index b53559ef3da..bbf72314536 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -10,6 +10,9 @@ def __init__(self, init_pin_memory: bool): self._manager_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() + def start_prepare(self): + TODO + class _ModelWeightSourceBase(ABC): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: From 384db827c321cbc147fe9cc30e3316f1d1448af9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:18:24 +0800 Subject: [PATCH 0886/1089] more --- .../model_executor/model_weight_updater.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index bbf72314536..536921af2f7 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,4 +1,5 @@ from abc import ABC +from dataclasses import dataclass from typing import Iterator, Tuple import torch @@ -7,6 +8,7 @@ class ModelWeightUpdater: def __init__(self, init_pin_memory: bool): + self._state: _State = _StateIdle() self._manager_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._model_weight_source = _ModelWeightSourcePinnedMemory() if 
init_pin_memory else _ModelWeightSourceVanilla() @@ -14,6 +16,25 @@ def start_prepare(self): TODO +class _State(ABC): + pass + + +@dataclass +class _StateIdle(_State): + pass + + +@dataclass +class _StateAwaitMemoryTransfer(_State): + pass + + +@dataclass +class _StatePrepared(_State): + named_tensors: List[Tuple[str, torch.Tensor]] + + class _ModelWeightSourceBase(ABC): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: raise NotImplementedError From b2fc2e49de296ddc1ec805d66cded3c5d577cac2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:19:17 +0800 Subject: [PATCH 0887/1089] more --- .../srt/model_executor/model_weight_updater.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 536921af2f7..b87e5e22f10 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass -from typing import Iterator, Tuple +from typing import Iterator, Tuple, List import torch from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager @@ -13,8 +13,22 @@ def __init__(self, init_pin_memory: bool): self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() def start_prepare(self): + assert isinstance(self._state, _StateIdle) + + TODO + + self._state = _StateAwaitMemoryTransfer() + + def event_loop_step(self): TODO + def act(self): + assert isinstance(self._state, _StatePrepared) + + TODO + + self._state = _StateIdle() + class _State(ABC): pass From bcb6a0455a17e601498401e71ea6a0aa89bd267d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:20:22 +0800 Subject: [PATCH 0888/1089] more --- .../sglang/srt/model_executor/model_weight_updater.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index b87e5e22f10..f73987ee539 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,13 +1,15 @@ from abc import ABC from dataclasses import dataclass -from typing import Iterator, Tuple, List +from typing import Iterator, Tuple, List, Callable import torch from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager class ModelWeightUpdater: - def __init__(self, init_pin_memory: bool): + def __init__(self, init_pin_memory: bool, weight_filter: Callable[[str], bool]): + self._weight_filter = weight_filter + self._state: _State = _StateIdle() self._manager_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() @@ -15,7 +17,8 @@ def __init__(self, init_pin_memory: bool): def start_prepare(self): assert isinstance(self._state, _StateIdle) - TODO + all_weights_iterator = self._model_weight_source.get_all_weights() + interesting_weights = [(name, weight) for name, weight in all_weights_iterator if self._weight_filter(name)] self._state = _StateAwaitMemoryTransfer() @@ -24,7 +27,7 @@ def event_loop_step(self): def act(self): assert isinstance(self._state, _StatePrepared) - + TODO self._state = 
_StateIdle() From c7f258fd631e45ec60c21aa9063a5e5320327c62 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:20:38 +0800 Subject: [PATCH 0889/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index f73987ee539..4acf4dc5f12 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -19,6 +19,7 @@ def start_prepare(self): all_weights_iterator = self._model_weight_source.get_all_weights() interesting_weights = [(name, weight) for name, weight in all_weights_iterator if self._weight_filter(name)] + self._manager_transfer_manager.enqueue(interesting_weights) self._state = _StateAwaitMemoryTransfer() From 8c132b05582a738b45444dd584cbc6d30637ff94 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:21:13 +0800 Subject: [PATCH 0890/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 4acf4dc5f12..751cb251606 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -11,7 +11,7 @@ def __init__(self, init_pin_memory: bool, weight_filter: Callable[[str], bool]): self._weight_filter = weight_filter self._state: _State = _StateIdle() - self._manager_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() + self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() def start_prepare(self): @@ -19,11 +19,12 @@ def start_prepare(self): all_weights_iterator = self._model_weight_source.get_all_weights() interesting_weights = [(name, weight) for name, weight in all_weights_iterator if self._weight_filter(name)] - self._manager_transfer_manager.enqueue(interesting_weights) + self._memory_transfer_manager.enqueue(interesting_weights) self._state = _StateAwaitMemoryTransfer() def event_loop_step(self): + memory_transfer_outputs = self._memory_transfer_manager.get_outputs() TODO def act(self): From afb7b7b790da9ba8e87d3bd447fc5226d7ccfa94 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:21:30 +0800 Subject: [PATCH 0891/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 751cb251606..dbc4275c127 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -25,7 +25,12 @@ def start_prepare(self): def event_loop_step(self): memory_transfer_outputs = self._memory_transfer_manager.get_outputs() + assert len(memory_transfer_outputs) in {0, 1} + if len(memory_transfer_outputs) == 0: + return False + TODO + return True def act(self): assert isinstance(self._state, _StatePrepared) From 819508f434a4eb2bd43f693ec3d74d0811db56fb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:22:16 +0800 Subject: [PATCH 0892/1089] more --- 
python/sglang/srt/model_executor/model_weight_updater.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index dbc4275c127..621fee212ae 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -29,7 +29,8 @@ def event_loop_step(self): if len(memory_transfer_outputs) == 0: return False - TODO + memory_transfer_output = memory_transfer_outputs[0] + self._state = _StatePrepared(named_tensors=memory_transfer_output) return True def act(self): From 5b30dfc2402ca982c5321324c09388398d90b2b6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:23:41 +0800 Subject: [PATCH 0893/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 621fee212ae..915b030fbc9 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -4,6 +4,8 @@ import torch from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager +from sglang.srt.model_loader.loader import DefaultModelLoader +from sglang.srt.model_loader.utils import set_default_torch_dtype class ModelWeightUpdater: @@ -36,7 +38,9 @@ def event_loop_step(self): def act(self): assert isinstance(self._state, _StatePrepared) - TODO + named_tensors = self._state.named_tensors + with set_default_torch_dtype(TODO): + DefaultModelLoader.load_weights_and_postprocess(model, named_tensors, target_device) self._state = _StateIdle() @@ -62,6 +66,7 @@ class _StatePrepared(_State): class _ModelWeightSourceBase(ABC): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + TODO_with_set_default_torch_dtype raise NotImplementedError From 568b6a751eab3c12077918d0e714f62be1b3b3f4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:24:39 +0800 Subject: [PATCH 0894/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 915b030fbc9..49851228110 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -39,6 +39,8 @@ def act(self): assert isinstance(self._state, _StatePrepared) named_tensors = self._state.named_tensors + + # TODO further extract such common operations during weight loading with set_default_torch_dtype(TODO): DefaultModelLoader.load_weights_and_postprocess(model, named_tensors, target_device) From 17c21dbb393cefc24142bedde8de1c934620bd2d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:25:20 +0800 Subject: [PATCH 0895/1089] more --- .../srt/model_executor/model_weight_updater.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 49851228110..063f085a15b 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -4,7 +4,7 @@ import torch from sglang.srt.model_executor.memory_transfer import 
AsyncToCudaManager, CombinedManager -from sglang.srt.model_loader.loader import DefaultModelLoader +from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -39,7 +39,7 @@ def act(self): assert isinstance(self._state, _StatePrepared) named_tensors = self._state.named_tensors - + # TODO further extract such common operations during weight loading with set_default_torch_dtype(TODO): DefaultModelLoader.load_weights_and_postprocess(model, named_tensors, target_device) @@ -74,7 +74,18 @@ def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: class _ModelWeightSourceVanilla(_ModelWeightSourceBase): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: - yield TODO + loader = get_model_loader(load_config) + assert isinstance(loader, DefaultModelLoader) + + return loader._get_weights_iterator( + DefaultModelLoader.Source( + config.model_path, + revision=config.revision, + fall_back_to_pt=getattr( + model, "fall_back_to_pt_during_load", True + ), + ) + ) class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): From 1c1a9ee6f8f1a1116de6dcd1a78dfaea38ef8067 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:25:35 +0800 Subject: [PATCH 0896/1089] more --- .../sglang/srt/model_executor/model_weight_updater.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 063f085a15b..c9f2d161438 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -76,16 +76,7 @@ class _ModelWeightSourceVanilla(_ModelWeightSourceBase): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) - - return loader._get_weights_iterator( - DefaultModelLoader.Source( - config.model_path, - revision=config.revision, - fall_back_to_pt=getattr( - model, "fall_back_to_pt_during_load", True - ), - ) - ) + return TODO class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): From 6981b8613aa11539e3f969013bdc33e7498883d9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:28:38 +0800 Subject: [PATCH 0897/1089] more --- python/sglang/srt/model_executor/model_runner.py | 8 +------- python/sglang/srt/model_loader/loader.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d8af8b32688..f4a917116cd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -519,13 +519,7 @@ def update_weights_from_disk( def get_weight_iter(config): iter = loader._get_weights_iterator( - DefaultModelLoader.Source( - config.model_path, - revision=config.revision, - fall_back_to_pt=getattr( - self.model, "fall_back_to_pt_during_load", True - ), - ) + DefaultModelLoader.Source.init_new(config, model) ) return iter diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index af0105c7ecf..530277064fd 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -197,6 +197,15 @@ class Source: fall_back_to_pt: bool = True """Whether .pt weights can be used.""" + @classmethod + def init_new(cls, model_config: ModelConfig, 
model): + return cls( + model_config.model_path, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -341,12 +350,7 @@ def _get_all_weights( model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source( - model_config.model_path, - model_config.revision, - prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), - ) + primary_weights = DefaultModelLoader.Source.init_new(model_config, model) yield from self._get_weights_iterator(primary_weights) secondary_weights = cast( From 74cb707ccc3203467ceecc6b224fd370563685dc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:30:49 +0800 Subject: [PATCH 0898/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index c9f2d161438..d0edc0d7cf3 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -3,6 +3,7 @@ from typing import Iterator, Tuple, List, Callable import torch +from sglang.srt.configs.load_config import LoadConfig from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -74,9 +75,11 @@ def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: class _ModelWeightSourceVanilla(_ModelWeightSourceBase): def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + load_config = LoadConfig(load_format=load_format) loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) - return TODO + with set_default_torch_dtype(model_config.dtype): + return loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model)) class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): From f721eb1a7070a5b353056032076c3a12264c2a55 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:31:14 +0800 Subject: [PATCH 0899/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index d0edc0d7cf3..46ed213a076 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -83,5 +83,9 @@ def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): + def __init__(self): + vanilla = _ModelWeightSourceVanilla() + all_weights = list(vanilla.get_all_weights()) + def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: - yield TODO + return TODO From b33211d565233c484a88297ecf97e2af9c2e1c1b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:32:03 +0800 Subject: [PATCH 0900/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py 
b/python/sglang/srt/model_executor/model_weight_updater.py index 46ed213a076..84e6836ce1f 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass -from typing import Iterator, Tuple, List, Callable +from typing import Iterator, Tuple, List, Callable, Iterable import torch from sglang.srt.configs.load_config import LoadConfig @@ -85,7 +85,11 @@ def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): def __init__(self): vanilla = _ModelWeightSourceVanilla() - all_weights = list(vanilla.get_all_weights()) + self._all_weights = _named_tensors_pin_memory(list(vanilla.get_all_weights())) def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: return TODO + + +def _named_tensors_pin_memory(named_tensors: Iterable[Tuple[str, torch.Tensor]]) -> List[Tuple[str, torch.Tensor]]: + return [(name, tensor.pin_memory()) for name, tensor in named_tensors] From ccc2fa5d7dc706b941abaa35c3b9b815d5e06671 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:32:28 +0800 Subject: [PATCH 0901/1089] more --- .../sglang/srt/model_executor/model_weight_updater.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 84e6836ce1f..b701e57b50c 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass -from typing import Iterator, Tuple, List, Callable, Iterable +from typing import Tuple, List, Callable, Iterable import torch from sglang.srt.configs.load_config import LoadConfig @@ -68,13 +68,13 @@ class _StatePrepared(_State): class _ModelWeightSourceBase(ABC): - def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: TODO_with_set_default_torch_dtype raise NotImplementedError class _ModelWeightSourceVanilla(_ModelWeightSourceBase): - def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: load_config = LoadConfig(load_format=load_format) loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) @@ -87,8 +87,8 @@ def __init__(self): vanilla = _ModelWeightSourceVanilla() self._all_weights = _named_tensors_pin_memory(list(vanilla.get_all_weights())) - def get_all_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: - return TODO + def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: + return self._all_weights def _named_tensors_pin_memory(named_tensors: Iterable[Tuple[str, torch.Tensor]]) -> List[Tuple[str, torch.Tensor]]: From 0b5efcf217c6efdb3e7c5615ecc8e0b90a86eb13 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:33:26 +0800 Subject: [PATCH 0902/1089] more --- .../srt/model_executor/model_weight_updater.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index b701e57b50c..6d55811d121 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -4,6 +4,7 @@ import torch from 
sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -74,17 +75,22 @@ def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: class _ModelWeightSourceVanilla(_ModelWeightSourceBase): + def __init__(self, load_format: str, model_config: ModelConfig, model): + self._load_format = load_format + self._model_config = model_config + self._model = model + def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: - load_config = LoadConfig(load_format=load_format) + load_config = LoadConfig(load_format=self._load_format) loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) - with set_default_torch_dtype(model_config.dtype): - return loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model)) + with set_default_torch_dtype(self._model_config.dtype): + return loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model)) class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): - def __init__(self): - vanilla = _ModelWeightSourceVanilla() + def __init__(self, *args, **kwargs): + vanilla = _ModelWeightSourceVanilla(*args, **kwargs) self._all_weights = _named_tensors_pin_memory(list(vanilla.get_all_weights())) def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: From bebd2591cbd376fb93c6530f4c0235bce8eaae9c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:33:58 +0800 Subject: [PATCH 0903/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 6d55811d121..c5bf07ba77d 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -85,7 +85,7 @@ def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) with set_default_torch_dtype(self._model_config.dtype): - return loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model)) + yield from loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model)) class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase): From 1904401657f008cae213dc58139c160402841c4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:34:06 +0800 Subject: [PATCH 0904/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index c5bf07ba77d..537fe34de16 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -70,7 +70,6 @@ class _StatePrepared(_State): class _ModelWeightSourceBase(ABC): def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]: - TODO_with_set_default_torch_dtype raise NotImplementedError From 53f896dfa38deccf6707b96e39dbccdddd344671 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:34:42 +0800 Subject: [PATCH 
0905/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 537fe34de16..f418917a0b6 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -14,9 +14,11 @@ class ModelWeightUpdater: def __init__(self, init_pin_memory: bool, weight_filter: Callable[[str], bool]): self._weight_filter = weight_filter - self._state: _State = _StateIdle() + ModelWeightSourceCls = _ModelWeightSourcePinnedMemory if init_pin_memory else _ModelWeightSourceVanilla + self._model_weight_source = ModelWeightSourceCls(load_format=TODO, model_config=TODO, model=TODO) self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() - self._model_weight_source = _ModelWeightSourcePinnedMemory() if init_pin_memory else _ModelWeightSourceVanilla() + + self._state: _State = _StateIdle() def start_prepare(self): assert isinstance(self._state, _StateIdle) From cf3f695b48137e19d6b7745a311ab7b00fa3b9b3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:36:10 +0800 Subject: [PATCH 0906/1089] more --- .../model_executor/model_weight_updater.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index f418917a0b6..5297344ad79 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -11,11 +11,23 @@ class ModelWeightUpdater: - def __init__(self, init_pin_memory: bool, weight_filter: Callable[[str], bool]): + def __init__( + self, + init_pin_memory: bool, + weight_filter: Callable[[str], bool], + load_format: str, + model_config: ModelConfig, + model, + device, + ): self._weight_filter = weight_filter + self._model_config = model_config + self._model = model + self._device = device ModelWeightSourceCls = _ModelWeightSourcePinnedMemory if init_pin_memory else _ModelWeightSourceVanilla - self._model_weight_source = ModelWeightSourceCls(load_format=TODO, model_config=TODO, model=TODO) + self._model_weight_source = ModelWeightSourceCls(load_format=load_format, model_config=model_config, + model=model) self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda() self._state: _State = _StateIdle() @@ -42,11 +54,12 @@ def event_loop_step(self): def act(self): assert isinstance(self._state, _StatePrepared) + target_device = torch.device(self._device) named_tensors = self._state.named_tensors # TODO further extract such common operations during weight loading - with set_default_torch_dtype(TODO): - DefaultModelLoader.load_weights_and_postprocess(model, named_tensors, target_device) + with set_default_torch_dtype(self._model_config.dtype): + DefaultModelLoader.load_weights_and_postprocess(self._model, named_tensors, target_device) self._state = _StateIdle() From c74310016700399f7e5888d2e3d02ded28f495c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:36:55 +0800 Subject: [PATCH 0907/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py 
b/python/sglang/srt/model_executor/model_weight_updater.py index 5297344ad79..fed3a30d215 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -42,6 +42,9 @@ def start_prepare(self): self._state = _StateAwaitMemoryTransfer() def event_loop_step(self): + TODO_maybe_rename + TODO_maybe_change_output + memory_transfer_outputs = self._memory_transfer_manager.get_outputs() assert len(memory_transfer_outputs) in {0, 1} if len(memory_transfer_outputs) == 0: From 034c41271abe9ab6e8612439f3bb5b508ed922d8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:37:11 +0800 Subject: [PATCH 0908/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index fed3a30d215..d01ebec0054 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -50,6 +50,7 @@ def event_loop_step(self): if len(memory_transfer_outputs) == 0: return False + assert isinstance(self._state, _StateAwaitMemoryTransfer) memory_transfer_output = memory_transfer_outputs[0] self._state = _StatePrepared(named_tensors=memory_transfer_output) return True From 7ef9b7fd78478e62024c6244fb7bc80f196463ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:37:40 +0800 Subject: [PATCH 0909/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index d01ebec0054..77dc42b74d2 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -50,10 +50,12 @@ def event_loop_step(self): if len(memory_transfer_outputs) == 0: return False + self._handle_memory_transfer_output(memory_transfer_outputs[0]) + return True + + def _handle_memory_transfer_output(self, memory_transfer_output): assert isinstance(self._state, _StateAwaitMemoryTransfer) - memory_transfer_output = memory_transfer_outputs[0] self._state = _StatePrepared(named_tensors=memory_transfer_output) - return True def act(self): assert isinstance(self._state, _StatePrepared) From 3bf81484aab07a3b0be9381dcb9f847082627bae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:38:48 +0800 Subject: [PATCH 0910/1089] more --- .../sglang/srt/model_executor/expert_location_updater.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 77878704903..83d38d83351 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -5,6 +5,7 @@ from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata +from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater from sglang.srt.utils import get_bool_env_var if TYPE_CHECKING: @@ -16,6 +17,14 @@ class ExpertLocationUpdater: def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner + 
self._model_weight_updater = ModelWeightUpdater( + init_pin_memory=TODO, + weight_filter=self._weight_filter, + load_format=TODO, + model_config=model_runner.model_config, + model=model_runner.model, + device=model_runner.device, + ) def start_prepare(self): TODO From 898b8d8c1c79467250a55016d905c98d5d1fd65e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:39:00 +0800 Subject: [PATCH 0911/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 83d38d83351..e811c3aba04 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -52,3 +52,6 @@ def act(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location end") torch.distributed.barrier() + + def _weight_filter(self, name: str): + return TODO From 998cd4b5d2b8c0c211fca17d09c4f917e7c388eb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:40:58 +0800 Subject: [PATCH 0912/1089] more --- .../sglang/srt/model_loader/weight_utils.py | 138 +++++++++--------- python/sglang/srt/models/deepseek_v2.py | 2 +- 2 files changed, 72 insertions(+), 68 deletions(-) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index d1c44e4f7aa..d9857227f3b 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -89,8 +89,8 @@ def _shared_pointers(tensors): def convert_bin_to_safetensor_file( - pt_filename: str, - sf_filename: str, + pt_filename: str, + sf_filename: str, ) -> None: loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) if "state_dict" in loaded: @@ -129,9 +129,9 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. def get_quant_config( - model_config: ModelConfig, - load_config: LoadConfig, - packed_modules_mapping: Dict[str, List[str]], + model_config: ModelConfig, + load_config: LoadConfig, + packed_modules_mapping: Dict[str, List[str]], ) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) @@ -154,8 +154,8 @@ def get_quant_config( # In case of bitsandbytes/QLoRA, get quant config from the adapter model. if model_config.quantization == "bitsandbytes": if ( - not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config ): return quant_cls.from_config({"adapter_name_or_path": ""}) model_name_or_path = load_config.model_loader_extra_config[ @@ -217,11 +217,11 @@ def get_quant_config( def download_weights_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None, - ignore_patterns: Optional[Union[str, List[str]]] = None, + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, ) -> str: """Download model weights from Hugging Face Hub. 
@@ -269,10 +269,10 @@ def download_weights_from_hf( def download_safetensors_index_file_from_hf( - model_name_or_path: str, - index_file: str, - cache_dir: Optional[str], - revision: Optional[str] = None, + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, ) -> None: """Download hf safetensors index file from Hugging Face Hub. @@ -308,7 +308,7 @@ def download_safetensors_index_file_from_hf( # So, we use the index_file to # look up which safetensors files should be used. def filter_duplicate_safetensors_files( - hf_weights_files: List[str], hf_folder: str, index_file: str + hf_weights_files: List[str], hf_folder: str, index_file: str ) -> List[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. @@ -355,17 +355,17 @@ def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[s def np_cache_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str], - hf_folder: str, - hf_weights_files: List[str], + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model np files. Will dump the model weights to numpy files if they are not already dumped. """ enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) # Convert the model weights from torch tensors to numpy arrays for # faster loading. @@ -378,10 +378,10 @@ def np_cache_weights_iterator( if not os.path.exists(weight_names_file): weight_names: List[str] = [] for bin_file in tqdm( - hf_weights_files, - desc="Loading np_cache checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): state = torch.load(bin_file, map_location="cpu", weights_only=True) for name, param in state.items(): @@ -407,17 +407,17 @@ def decrypt(fn, key): def safetensors_encrypted_weights_iterator( - hf_weights_files: List[str], - is_all_weights_sharded: bool = False, - decryption_key: Optional[str] = None, + hf_weights_files: List[str], + is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, ): raise NotImplementedError() def safetensors_weights_iterator( - hf_weights_files: List[str], - is_all_weights_sharded: bool = False, - decryption_key: Optional[str] = None, + hf_weights_files: List[str], + is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. 
@@ -431,13 +431,13 @@ def safetensors_weights_iterator( return enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): result = safetensors.torch.load_file(st_file, device="cpu") for name, param in result.items(): @@ -445,17 +445,17 @@ def safetensors_weights_iterator( def pt_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) for bin_file in tqdm( - hf_weights_files, - desc="Loading pt checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): state = torch.load(bin_file, map_location="cpu", weights_only=True) yield from state.items() @@ -463,7 +463,7 @@ def pt_weights_iterator( def get_gguf_extra_tensor_names( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str] + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] ) -> List[str]: import gguf @@ -475,7 +475,7 @@ def get_gguf_extra_tensor_names( def gguf_quant_weights_iterator( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str] + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: """ Iterate over the quant weights in the model gguf files and convert @@ -545,7 +545,7 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> N def row_parallel_weight_loader( - param: torch.Tensor, loaded_weight: torch.Tensor + param: torch.Tensor, loaded_weight: torch.Tensor ) -> None: """Load weights that are row-parallelized.""" tp_rank = get_tensor_model_parallel_rank() @@ -578,7 +578,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def composed_weight_loader( - loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] ) -> LoaderFunction: """Create a weight loader that post-processes the weights after loading""" @@ -591,21 +591,21 @@ def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def runai_safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" from runai_model_streamer import SafetensorsStreamer enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) with SafetensorsStreamer() as streamer: for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors using Runai Model Streamer", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading safetensors using Runai Model Streamer", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): streamer.stream_file(st_file) yield from streamer.get_tensors() @@ -616,14 +616,14 @@ def 
set_runai_streamer_env(load_config: LoadConfig): extra_config = load_config.model_loader_extra_config if "concurrency" in extra_config and isinstance( - extra_config.get("concurrency"), int + extra_config.get("concurrency"), int ): os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( extra_config.get("concurrency") ) if "memory_limit" in extra_config and isinstance( - extra_config.get("memory_limit"), int + extra_config.get("memory_limit"), int ): os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( extra_config.get("memory_limit") @@ -636,10 +636,10 @@ def set_runai_streamer_env(load_config: LoadConfig): def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 1234, + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, ) -> None: """Initialize model weights with random values. @@ -710,8 +710,8 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: if name.endswith(scale_name): # Check and remap the name based on modelopt scale names if any( - modelopt_scale_name in name - for modelopt_scale_name in modelopt_scale_names + modelopt_scale_name in name + for modelopt_scale_name in modelopt_scale_names ): remapped_name = name.replace( f".self_attn.{scale_name[1]}_proj{scale_name}", @@ -768,7 +768,7 @@ def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": ) for i in range(tp_size): assert ( - i in self.scaling_factor + i in self.scaling_factor ), f"KV cache scales map for TP rank {i} not found." return self @@ -809,11 +809,11 @@ def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": def kv_cache_scales_loader( - filename: str, - tp_rank: int, - tp_size: int, - num_hidden_layers: int, - model_type: Optional[str], + filename: str, + tp_rank: int, + tp_size: int, + num_hidden_layers: int, + model_type: Optional[str], ) -> Iterable[Tuple[int, float]]: """ A simple utility to read in KV cache scaling factors that have been @@ -849,3 +849,7 @@ def kv_cache_scales_loader( tp_rank, ) return [] + + +class ModelParamNameInfo(ABC): + pass diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 1e9e9a47408..5e2aa2cb6be 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1681,7 +1681,7 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() - def get_param_category(self, name): + def get_param_name_info(self, name: str) -> ModelParamNameInfo: if ".experts." 
in name: return "moe" return "others" From f6d0ec42734fb8a67e43ef5213cbfd6db1488862 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:41:23 +0800 Subject: [PATCH 0913/1089] more --- .../sglang/srt/model_loader/weight_utils.py | 149 ++++++++++-------- 1 file changed, 79 insertions(+), 70 deletions(-) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index d9857227f3b..8c427e56f2a 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -9,6 +9,7 @@ import os import tempfile from collections import defaultdict +from dataclasses import dataclass from typing import ( Any, Callable, @@ -28,13 +29,12 @@ import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator -from tqdm.auto import tqdm - from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.utils import print_warning_once +from tqdm.auto import tqdm logger = logging.getLogger(__name__) @@ -89,8 +89,8 @@ def _shared_pointers(tensors): def convert_bin_to_safetensor_file( - pt_filename: str, - sf_filename: str, + pt_filename: str, + sf_filename: str, ) -> None: loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) if "state_dict" in loaded: @@ -129,9 +129,9 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. def get_quant_config( - model_config: ModelConfig, - load_config: LoadConfig, - packed_modules_mapping: Dict[str, List[str]], + model_config: ModelConfig, + load_config: LoadConfig, + packed_modules_mapping: Dict[str, List[str]], ) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) @@ -154,8 +154,8 @@ def get_quant_config( # In case of bitsandbytes/QLoRA, get quant config from the adapter model. if model_config.quantization == "bitsandbytes": if ( - not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config ): return quant_cls.from_config({"adapter_name_or_path": ""}) model_name_or_path = load_config.model_loader_extra_config[ @@ -217,11 +217,11 @@ def get_quant_config( def download_weights_from_hf( - model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None, - ignore_patterns: Optional[Union[str, List[str]]] = None, + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, ) -> str: """Download model weights from Hugging Face Hub. @@ -269,10 +269,10 @@ def download_weights_from_hf( def download_safetensors_index_file_from_hf( - model_name_or_path: str, - index_file: str, - cache_dir: Optional[str], - revision: Optional[str] = None, + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, ) -> None: """Download hf safetensors index file from Hugging Face Hub. 
@@ -308,7 +308,7 @@ def download_safetensors_index_file_from_hf( # So, we use the index_file to # look up which safetensors files should be used. def filter_duplicate_safetensors_files( - hf_weights_files: List[str], hf_folder: str, index_file: str + hf_weights_files: List[str], hf_folder: str, index_file: str ) -> List[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. @@ -355,17 +355,17 @@ def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[s def np_cache_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str], - hf_folder: str, - hf_weights_files: List[str], + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model np files. Will dump the model weights to numpy files if they are not already dumped. """ enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) # Convert the model weights from torch tensors to numpy arrays for # faster loading. @@ -378,10 +378,10 @@ def np_cache_weights_iterator( if not os.path.exists(weight_names_file): weight_names: List[str] = [] for bin_file in tqdm( - hf_weights_files, - desc="Loading np_cache checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): state = torch.load(bin_file, map_location="cpu", weights_only=True) for name, param in state.items(): @@ -407,17 +407,17 @@ def decrypt(fn, key): def safetensors_encrypted_weights_iterator( - hf_weights_files: List[str], - is_all_weights_sharded: bool = False, - decryption_key: Optional[str] = None, + hf_weights_files: List[str], + is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, ): raise NotImplementedError() def safetensors_weights_iterator( - hf_weights_files: List[str], - is_all_weights_sharded: bool = False, - decryption_key: Optional[str] = None, + hf_weights_files: List[str], + is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. 
@@ -431,13 +431,13 @@ def safetensors_weights_iterator( return enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): result = safetensors.torch.load_file(st_file, device="cpu") for name, param in result.items(): @@ -445,17 +445,17 @@ def safetensors_weights_iterator( def pt_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) for bin_file in tqdm( - hf_weights_files, - desc="Loading pt checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): state = torch.load(bin_file, map_location="cpu", weights_only=True) yield from state.items() @@ -463,7 +463,7 @@ def pt_weights_iterator( def get_gguf_extra_tensor_names( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str] + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] ) -> List[str]: import gguf @@ -475,7 +475,7 @@ def get_gguf_extra_tensor_names( def gguf_quant_weights_iterator( - gguf_file: str, gguf_to_hf_name_map: Dict[str, str] + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: """ Iterate over the quant weights in the model gguf files and convert @@ -545,7 +545,7 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> N def row_parallel_weight_loader( - param: torch.Tensor, loaded_weight: torch.Tensor + param: torch.Tensor, loaded_weight: torch.Tensor ) -> None: """Load weights that are row-parallelized.""" tp_rank = get_tensor_model_parallel_rank() @@ -578,7 +578,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def composed_weight_loader( - loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] ) -> LoaderFunction: """Create a weight loader that post-processes the weights after loading""" @@ -591,21 +591,21 @@ def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def runai_safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" from runai_model_streamer import SafetensorsStreamer enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) with SafetensorsStreamer() as streamer: for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors using Runai Model Streamer", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading safetensors using Runai Model Streamer", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): streamer.stream_file(st_file) yield from streamer.get_tensors() @@ -616,14 +616,14 @@ def 
set_runai_streamer_env(load_config: LoadConfig): extra_config = load_config.model_loader_extra_config if "concurrency" in extra_config and isinstance( - extra_config.get("concurrency"), int + extra_config.get("concurrency"), int ): os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( extra_config.get("concurrency") ) if "memory_limit" in extra_config and isinstance( - extra_config.get("memory_limit"), int + extra_config.get("memory_limit"), int ): os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( extra_config.get("memory_limit") @@ -636,10 +636,10 @@ def set_runai_streamer_env(load_config: LoadConfig): def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 1234, + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, ) -> None: """Initialize model weights with random values. @@ -710,8 +710,8 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: if name.endswith(scale_name): # Check and remap the name based on modelopt scale names if any( - modelopt_scale_name in name - for modelopt_scale_name in modelopt_scale_names + modelopt_scale_name in name + for modelopt_scale_name in modelopt_scale_names ): remapped_name = name.replace( f".self_attn.{scale_name[1]}_proj{scale_name}", @@ -768,7 +768,7 @@ def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": ) for i in range(tp_size): assert ( - i in self.scaling_factor + i in self.scaling_factor ), f"KV cache scales map for TP rank {i} not found." return self @@ -809,11 +809,11 @@ def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": def kv_cache_scales_loader( - filename: str, - tp_rank: int, - tp_size: int, - num_hidden_layers: int, - model_type: Optional[str], + filename: str, + tp_rank: int, + tp_size: int, + num_hidden_layers: int, + model_type: Optional[str], ) -> Iterable[Tuple[int, float]]: """ A simple utility to read in KV cache scaling factors that have been @@ -851,5 +851,14 @@ def kv_cache_scales_loader( return [] -class ModelParamNameInfo(ABC): +@dataclass +class ModelParamNameInfoMoe: pass + + +@dataclass +class ModelParamNameInfoOthers: + pass + + +ModelParamNameInfo = Union[ModelParamNameInfoMoe, ModelParamNameInfoOthers] From 388ce3136fececa97c8952883150d24f92941e5b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:41:46 +0800 Subject: [PATCH 0914/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/model_loader/weight_utils.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f4a917116cd..0d2d891d5a5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -528,7 +528,7 @@ def filter_weight_iter(iter: Iterable[Tuple[str, torch.Tensor]]): yield from iter else: for name, weight in iter: - if self.model.get_param_category(name) in param_categories: + if self.model.get_param_name_info(name).category in param_categories: yield name, weight def model_load_weights(model, iter): diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 8c427e56f2a..bf803f73431 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -853,12 +853,16 @@ def kv_cache_scales_loader( @dataclass class ModelParamNameInfoMoe: - pass + @property + def 
category(self): + return "moe" @dataclass class ModelParamNameInfoOthers: - pass + @property + def category(self): + return "others" ModelParamNameInfo = Union[ModelParamNameInfoMoe, ModelParamNameInfoOthers] From efdb31814ead2e490cbd6e506098f71029d28d72 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:42:12 +0800 Subject: [PATCH 0915/1089] more --- python/sglang/srt/models/deepseek_v2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5e2aa2cb6be..3a9ef62f2aa 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -78,7 +78,8 @@ global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import default_weight_loader, ModelParamNameInfo, ModelParamNameInfoMoe, \ + ModelParamNameInfoOthers from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip from torch import nn from tqdm import tqdm @@ -1683,8 +1684,8 @@ def set_embed_and_head(self, embed, head): def get_param_name_info(self, name: str) -> ModelParamNameInfo: if ".experts." in name: - return "moe" - return "others" + return ModelParamNameInfoMoe(TODO) + return ModelParamNameInfoOthers() @classmethod def get_model_config_for_expert_location(cls, config): From 1baa588e291df704e33c864caae864708e24513f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:42:37 +0800 Subject: [PATCH 0916/1089] more --- python/sglang/srt/model_loader/weight_utils.py | 3 +++ python/sglang/srt/models/deepseek_v2.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index bf803f73431..e86001b41bd 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -853,6 +853,9 @@ def kv_cache_scales_loader( @dataclass class ModelParamNameInfoMoe: + layer_id: int + expert_id: int + @property def category(self): return "moe" diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3a9ef62f2aa..9a2dad37c30 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1684,7 +1684,10 @@ def set_embed_and_head(self, embed, head): def get_param_name_info(self, name: str) -> ModelParamNameInfo: if ".experts." in name: - return ModelParamNameInfoMoe(TODO) + return ModelParamNameInfoMoe( + layer_id=TODO, + expert_id=TODO, + ) return ModelParamNameInfoOthers() @classmethod From e54fe0507aa3ddebcb7e65a6f44d024d4c11c652 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:43:02 +0800 Subject: [PATCH 0917/1089] more --- python/sglang/srt/models/deepseek_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9a2dad37c30..fc7ea7cbe99 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -18,6 +18,7 @@ import logging import os +import re from dataclasses import dataclass from enum import Enum, auto from typing import Any, Dict, Iterable, Optional, Tuple @@ -1685,8 +1686,8 @@ def set_embed_and_head(self, embed, head): def get_param_name_info(self, name: str) -> ModelParamNameInfo: if ".experts." 
in name: return ModelParamNameInfoMoe( - layer_id=TODO, - expert_id=TODO, + layer_id=int(re.search(r"layers\.(\d+)", name).group(1)), + expert_id=int(re.search(r"experts\.(\d+)", name).group(1)), ) return ModelParamNameInfoOthers() From ac0058ed82979c00fb467fe478ab3d997c3d5878 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:43:50 +0800 Subject: [PATCH 0918/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index e811c3aba04..59cdcfc15b2 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -6,6 +6,7 @@ from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater +from sglang.srt.model_loader.weight_utils import ModelParamNameInfo, ModelParamNameInfoMoe from sglang.srt.utils import get_bool_env_var if TYPE_CHECKING: @@ -54,4 +55,8 @@ def act(self, recv_req: UpdateExpertLocationReqInput): torch.distributed.barrier() def _weight_filter(self, name: str): + info: ModelParamNameInfo = self._model_runner.model.get_param_name_info() + if not isinstance(info, ModelParamNameInfoMoe): + return False + return TODO From f80431bf5483b6a3a7548adc18074a5ca39d66c0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:44:33 +0800 Subject: [PATCH 0919/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 5 +++-- python/sglang/srt/model_executor/model_weight_updater.py | 6 ++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 59cdcfc15b2..202e4a5530f 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -20,7 +20,6 @@ def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner self._model_weight_updater = ModelWeightUpdater( init_pin_memory=TODO, - weight_filter=self._weight_filter, load_format=TODO, model_config=model_runner.model_config, model=model_runner.model, @@ -28,7 +27,9 @@ def __init__(self, model_runner: "ModelRunner"): ) def start_prepare(self): - TODO + self._model_weight_updater.start_prepare( + weight_filter=lambda name: self._weight_filter(name), + ) def act(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location start") diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 77dc42b74d2..3999eda074c 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -14,13 +14,11 @@ class ModelWeightUpdater: def __init__( self, init_pin_memory: bool, - weight_filter: Callable[[str], bool], load_format: str, model_config: ModelConfig, model, device, ): - self._weight_filter = weight_filter self._model_config = model_config self._model = model self._device = device @@ -32,11 +30,11 @@ def __init__( self._state: _State = _StateIdle() - def start_prepare(self): + def start_prepare(self, weight_filter: Callable[[str], bool]): assert isinstance(self._state, _StateIdle) 
all_weights_iterator = self._model_weight_source.get_all_weights() - interesting_weights = [(name, weight) for name, weight in all_weights_iterator if self._weight_filter(name)] + interesting_weights = [(name, weight) for name, weight in all_weights_iterator if weight_filter(name)] self._memory_transfer_manager.enqueue(interesting_weights) self._state = _StateAwaitMemoryTransfer() From 6cfae60e0bc9b70e7de89316f44f93ca5fa66ade Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:45:23 +0800 Subject: [PATCH 0920/1089] more --- .../srt/model_executor/expert_location_updater.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 202e4a5530f..29869441f4e 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -1,8 +1,9 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List import torch from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater @@ -61,3 +62,10 @@ def _weight_filter(self, name: str): return False return TODO + + +def _compute_interesting_logical_experts_of_layer( + old_expert_location_metadata: ExpertLocationMetadata, + new_expert_location_metadata: ExpertLocationMetadata, +) -> Dict[int, List[int]]: + return TODO From bcf6ddb28020acacab54a88018941f6cab0fa8c3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:46:28 +0800 Subject: [PATCH 0921/1089] more --- .../model_executor/expert_location_updater.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 29869441f4e..eebf9c7e8db 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -28,8 +28,13 @@ def __init__(self, model_runner: "ModelRunner"): ) def start_prepare(self): + interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( + old_expert_location_metadata=TODO, + new_expert_location_metadata=TODO, + ) + self._model_weight_updater.start_prepare( - weight_filter=lambda name: self._weight_filter(name), + weight_filter=lambda name: self._weight_filter(name, interesting_logical_experts_of_layer), ) def act(self, recv_req: UpdateExpertLocationReqInput): @@ -56,12 +61,12 @@ def act(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location end") torch.distributed.barrier() - def _weight_filter(self, name: str): - info: ModelParamNameInfo = self._model_runner.model.get_param_name_info() - if not isinstance(info, ModelParamNameInfoMoe): - return False - - return TODO + def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]): + info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name) + return ( + isinstance(info, ModelParamNameInfoMoe) + and (info.expert_id in interesting_logical_experts_of_layer[info.layer_id]) + ) def 
_compute_interesting_logical_experts_of_layer( From c3f483dccc9d437ada9696dacfd8cfb8f6182563 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:48:07 +0800 Subject: [PATCH 0922/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index eebf9c7e8db..3bfc9700c3f 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -73,4 +73,7 @@ def _compute_interesting_logical_experts_of_layer( old_expert_location_metadata: ExpertLocationMetadata, new_expert_location_metadata: ExpertLocationMetadata, ) -> Dict[int, List[int]]: - return TODO + interesting_logical_experts_of_layer = {} + for layer_id in range(old_expert_location_metadata.num_layers): + interesting_logical_experts_of_layer[layer_id] = TODO + return interesting_logical_experts_of_layer From 562d4dbea1a4f6e621a63115567ecc2f7132433d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:49:29 +0800 Subject: [PATCH 0923/1089] more --- .../srt/model_executor/expert_location_updater.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 3bfc9700c3f..1c23ba631ca 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -72,8 +72,18 @@ def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[i def _compute_interesting_logical_experts_of_layer( old_expert_location_metadata: ExpertLocationMetadata, new_expert_location_metadata: ExpertLocationMetadata, + ep_rank: int, ) -> Dict[int, List[int]]: + num_layers = old_expert_location_metadata.num_layers + num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts + + def _get_partial_physical_to_logical_map(meta: ExpertLocationMetadata, layer_id: int): + return meta.physical_to_logical_map[layer_id, + num_local_physical_experts * ep_rank: num_local_physical_experts * (ep_rank + 1)] + interesting_logical_experts_of_layer = {} - for layer_id in range(old_expert_location_metadata.num_layers): + for layer_id in range(num_layers): + old_partial_map = _get_partial_physical_to_logical_map(old_expert_location_metadata) + new_partial_map = _get_partial_physical_to_logical_map(new_expert_location_metadata) interesting_logical_experts_of_layer[layer_id] = TODO return interesting_logical_experts_of_layer From c14830371f259272379ceceb30d7fc91d9b21249 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:49:47 +0800 Subject: [PATCH 0924/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 1c23ba631ca..ab8d4c487bb 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -85,5 +85,5 @@ def _get_partial_physical_to_logical_map(meta: ExpertLocationMetadata, layer_id: for layer_id in range(num_layers): old_partial_map = _get_partial_physical_to_logical_map(old_expert_location_metadata) new_partial_map = 
_get_partial_physical_to_logical_map(new_expert_location_metadata) - interesting_logical_experts_of_layer[layer_id] = TODO + interesting_logical_experts_of_layer[layer_id] = new_partial_map[new_partial_map != old_partial_map].tolist() return interesting_logical_experts_of_layer From af35f39cc50b8b13006700602b61d7b3196367c3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:49:59 +0800 Subject: [PATCH 0925/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index ab8d4c487bb..1f306ddb719 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -83,7 +83,7 @@ def _get_partial_physical_to_logical_map(meta: ExpertLocationMetadata, layer_id: interesting_logical_experts_of_layer = {} for layer_id in range(num_layers): - old_partial_map = _get_partial_physical_to_logical_map(old_expert_location_metadata) - new_partial_map = _get_partial_physical_to_logical_map(new_expert_location_metadata) + old_partial_map = _get_partial_physical_to_logical_map(old_expert_location_metadata, layer_id) + new_partial_map = _get_partial_physical_to_logical_map(new_expert_location_metadata, layer_id) interesting_logical_experts_of_layer[layer_id] = new_partial_map[new_partial_map != old_partial_map].tolist() return interesting_logical_experts_of_layer From ed853342d565afe2d464cc9cefa7432a7e509d8d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:50:32 +0800 Subject: [PATCH 0926/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 1f306ddb719..96b2bab24ba 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -27,10 +27,11 @@ def __init__(self, model_runner: "ModelRunner"): device=model_runner.device, ) - def start_prepare(self): + def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( old_expert_location_metadata=TODO, - new_expert_location_metadata=TODO, + new_expert_location_metadata=expert_location_metadata, + ep_rank=self._model_runner.tp_rank, ) self._model_weight_updater.start_prepare( From 51a279fb6dd9aa5914f29ff91c07059539216444 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:50:56 +0800 Subject: [PATCH 0927/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 96b2bab24ba..24c259d2ea0 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -29,7 +29,7 @@ def __init__(self, model_runner: "ModelRunner"): def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( - old_expert_location_metadata=TODO, + old_expert_location_metadata=get_global_expert_location_metadata(), 
new_expert_location_metadata=expert_location_metadata, ep_rank=self._model_runner.tp_rank, ) From c7290594392fe399aebc71e9be91cb4ee154ce1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:51:31 +0800 Subject: [PATCH 0928/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 24c259d2ea0..3d7153df3ca 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -38,6 +38,12 @@ def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): weight_filter=lambda name: self._weight_filter(name, interesting_logical_experts_of_layer), ) + def event_loop_step(self): + TODO_maybe_rename + TODO_maybe_change_output + self._model_weight_updater.event_loop_step() + return TODO + def act(self, recv_req: UpdateExpertLocationReqInput): logger.info("update_expert_location start") torch.distributed.barrier() From e1375048de9531e024736e318b39f2641068c12c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:52:12 +0800 Subject: [PATCH 0929/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index d17dabcf224..a9f2f044974 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -14,6 +14,10 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): + def __init__(self, manager_a: TensorOperationManagerBase, manager_b: TensorOperationManagerBase): + self._manager_a = manager_a + self._manager_b = manager_b + @classmethod def init_pin_memory_and_to_cuda(cls): return cls(manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager()) From 523956765a0a101a7d4018672609458315da3ff3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:52:31 +0800 Subject: [PATCH 0930/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index a9f2f044974..8082099c215 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -15,6 +15,7 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): def __init__(self, manager_a: TensorOperationManagerBase, manager_b: TensorOperationManagerBase): + # For simplicity, only support chaining 2 managers, but can be extended to N self._manager_a = manager_a self._manager_b = manager_b From b7905c3873a712d497613dad8017c1022d84c021 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:52:41 +0800 Subject: [PATCH 0931/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 8082099c215..e20f02a52c7 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -24,7 +24,7 @@ def init_pin_memory_and_to_cuda(cls): return cls(manager_a=AsyncPinMemoryManager(), 
manager_b=AsyncToCudaManager()) def enqueue(self, named_tensors: NamedTensors): - TODO + self._manager_a.enqueue(named_tensors) def get_outputs(self) -> List[NamedTensors]: return TODO From c7869358d96060034324a30fc980981f02c02070 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:53:07 +0800 Subject: [PATCH 0932/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index e20f02a52c7..94e620f57a4 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -27,7 +27,11 @@ def enqueue(self, named_tensors: NamedTensors): self._manager_a.enqueue(named_tensors) def get_outputs(self) -> List[NamedTensors]: - return TODO + outputs_a = self._manager_a.get_outputs() + for output_a in outputs_a: + self._manager_b.enqueue(output_a) + + return self._manager_b.get_outputs() class AsyncPinMemoryManager(TensorOperationManagerBase): From c89ba31eebae4044c75a6601741722c3be204997 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:53:46 +0800 Subject: [PATCH 0933/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 94e620f57a4..f413e79e2da 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import List, Tuple import torch @@ -48,3 +49,9 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: return TODO + + +@dataclass +class _AsyncToCudaTask: + event: torch.cuda.Event + output_named_tensors: NamedTensors From 8ded71425547a7ffeccbe37df7efc1e901ebca60 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:54:16 +0800 Subject: [PATCH 0934/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index f413e79e2da..d0b62580b95 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -44,8 +44,15 @@ def get_outputs(self) -> List[NamedTensors]: class AsyncToCudaManager(TensorOperationManagerBase): + def __init__(self): + self._queue = [] + def enqueue(self, named_tensors: NamedTensors): - TODO + self._queue.append(_AsyncToCudaTask( + event=event, + input_named_tensors=named_tensors, + output_named_tensors=output_named_tensors, + )) def get_outputs(self) -> List[NamedTensors]: return TODO @@ -54,4 +61,5 @@ def get_outputs(self) -> List[NamedTensors]: @dataclass class _AsyncToCudaTask: event: torch.cuda.Event + input_named_tensors: NamedTensors output_named_tensors: NamedTensors From debef349abe6fa26eab10c9185384181c9da5f5f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:55:03 +0800 Subject: [PATCH 0935/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index d0b62580b95..ce2069ef79e 100644 --- 
a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Tuple +from typing import List, Tuple, Optional import torch @@ -46,8 +46,10 @@ def get_outputs(self) -> List[NamedTensors]: class AsyncToCudaManager(TensorOperationManagerBase): def __init__(self): self._queue = [] + self._alt_stream: Optional[torch.cuda.Stream] = None def enqueue(self, named_tensors: NamedTensors): + self._auto_create_stream() self._queue.append(_AsyncToCudaTask( event=event, input_named_tensors=named_tensors, @@ -57,6 +59,10 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: return TODO + def _auto_create_stream(self): + if self._alt_stream is None: + self._alt_stream = torch.cuda.Stream() + @dataclass class _AsyncToCudaTask: From 7c91cf27f56fb2d1b9695b0266fa2c58857a787e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:56:03 +0800 Subject: [PATCH 0936/1089] more --- .../sglang/srt/model_executor/memory_transfer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index ce2069ef79e..94899a99170 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -50,8 +50,18 @@ def __init__(self): def enqueue(self, named_tensors: NamedTensors): self._auto_create_stream() + + self._alt_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._alt_stream): + output_named_tensors = [ + (name, tensor.to("cuda", non_blocking=True)) + for name, tensor in named_tensors + ] + finish_event = torch.cuda.Event() + finish_event.record() + self._queue.append(_AsyncToCudaTask( - event=event, + finish_event=finish_event, input_named_tensors=named_tensors, output_named_tensors=output_named_tensors, )) @@ -66,6 +76,6 @@ def _auto_create_stream(self): @dataclass class _AsyncToCudaTask: - event: torch.cuda.Event + finish_event: torch.cuda.Event input_named_tensors: NamedTensors output_named_tensors: NamedTensors From 36d2f1b0555a67f7771ed8ccf82133081b3a7d8a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:56:29 +0800 Subject: [PATCH 0937/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 94899a99170..ac4c14197bd 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -45,7 +45,7 @@ def get_outputs(self) -> List[NamedTensors]: class AsyncToCudaManager(TensorOperationManagerBase): def __init__(self): - self._queue = [] + self._inflight_tasks = [] self._alt_stream: Optional[torch.cuda.Stream] = None def enqueue(self, named_tensors: NamedTensors): @@ -60,14 +60,17 @@ def enqueue(self, named_tensors: NamedTensors): finish_event = torch.cuda.Event() finish_event.record() - self._queue.append(_AsyncToCudaTask( + self._inflight_tasks.append(_AsyncToCudaTask( finish_event=finish_event, input_named_tensors=named_tensors, output_named_tensors=output_named_tensors, )) def get_outputs(self) -> List[NamedTensors]: - return TODO + outputs = [] + while len(self._inflight_tasks) > 0: + TODO + return outputs def _auto_create_stream(self): if self._alt_stream is None: 
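Taken together, the ModelWeightUpdater patches above (0905 onward) define a small three-state machine: start_prepare() filters the weight source and enqueues the surviving tensors on the memory-transfer manager (Idle -> AwaitMemoryTransfer), event_loop_step() polls until the transfer completes (-> Prepared), and act() loads the prepared tensors into the model and returns to Idle. The following self-contained sketch restates that flow; the stubbed transfer manager and injected loader callback are assumptions for illustration, not sglang's actual API.

    # Minimal sketch of the ModelWeightUpdater state machine (illustrative only).
    from dataclasses import dataclass
    from typing import Callable, Iterable, List, Tuple

    import torch

    NamedTensors = List[Tuple[str, torch.Tensor]]


    class _StateIdle:
        pass


    class _StateAwaitMemoryTransfer:
        pass


    @dataclass
    class _StatePrepared:
        named_tensors: NamedTensors


    class SketchModelWeightUpdater:
        def __init__(
            self,
            transfer_manager,  # enqueue()/get_outputs(), like AsyncToCudaManager
            weight_source: Iterable[Tuple[str, torch.Tensor]],  # (name, tensor) pairs
            load_fn: Callable[[NamedTensors], None],  # applies tensors to the model
        ):
            self._transfer_manager = transfer_manager
            self._weight_source = weight_source
            self._load_fn = load_fn
            self._state = _StateIdle()

        def start_prepare(self, weight_filter: Callable[[str], bool]):
            assert isinstance(self._state, _StateIdle)
            interesting = [(n, w) for n, w in self._weight_source if weight_filter(n)]
            self._transfer_manager.enqueue(interesting)
            self._state = _StateAwaitMemoryTransfer()

        def event_loop_step(self) -> bool:
            # Returns True exactly when a finished transfer moved us to Prepared.
            outputs = self._transfer_manager.get_outputs()
            assert len(outputs) in {0, 1}
            if len(outputs) == 0:
                return False
            assert isinstance(self._state, _StateAwaitMemoryTransfer)
            self._state = _StatePrepared(named_tensors=outputs[0])
            return True

        def act(self):
            assert isinstance(self._state, _StatePrepared)
            self._load_fn(self._state.named_tensors)
            self._state = _StateIdle()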
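The name-based routing introduced around patches 0913-0917 (ModelParamNameInfoMoe vs ModelParamNameInfoOthers, plus get_param_name_info in deepseek_v2) can be exercised standalone. A small sketch follows; the parameter name at the bottom is a hypothetical DeepSeek-style name chosen for illustration, not taken from a real checkpoint.

    # Sketch of the param-name classification from the deepseek_v2 patches.
    import re
    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class ModelParamNameInfoMoe:
        layer_id: int
        expert_id: int


    @dataclass
    class ModelParamNameInfoOthers:
        pass


    ModelParamNameInfo = Union[ModelParamNameInfoMoe, ModelParamNameInfoOthers]


    def get_param_name_info(name: str) -> ModelParamNameInfo:
        # Expert weights carry both their layer index and expert index in the name.
        if ".experts." in name:
            return ModelParamNameInfoMoe(
                layer_id=int(re.search(r"layers\.(\d+)", name).group(1)),
                expert_id=int(re.search(r"experts\.(\d+)", name).group(1)),
            )
        return ModelParamNameInfoOthers()


    # Hypothetical parameter name, for illustration only.
    info = get_param_name_info("model.layers.3.mlp.experts.17.down_proj.weight")
    assert info == ModelParamNameInfoMoe(layer_id=3, expert_id=17)

The Union-of-dataclasses shape keeps the string category trivially derivable while still carrying the layer and expert ids that the expert-location filter needs.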
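The weight filter built up in the expert-location patches above ultimately reduces to _compute_interesting_logical_experts_of_layer: slice this EP rank's portion of the old and new physical_to_logical_map per layer, and keep the logical expert ids wherever the two slices disagree. A hedged, runnable rendering with a tiny worked example; the [num_layers, num_physical_experts] tensor shape is an assumption inferred from the slicing in the diff.

    # Sketch: which logical experts does this rank need fresh weights for?
    from typing import Dict, List

    import torch


    def interesting_logical_experts(
        old_map: torch.Tensor,  # assumed shape [num_layers, num_physical_experts]
        new_map: torch.Tensor,  # same shape as old_map
        num_local_physical_experts: int,
        ep_rank: int,
    ) -> Dict[int, List[int]]:
        start = num_local_physical_experts * ep_rank
        end = num_local_physical_experts * (ep_rank + 1)
        out: Dict[int, List[int]] = {}
        for layer_id in range(old_map.shape[0]):
            old_partial = old_map[layer_id, start:end]
            new_partial = new_map[layer_id, start:end]
            # Slots whose logical expert changed must be (re)loaded on this rank.
            out[layer_id] = new_partial[new_partial != old_partial].tolist()
        return out


    # Worked example: 1 layer, 2 ranks x 2 local physical slots each.
    old = torch.tensor([[0, 1, 2, 3]])
    new = torch.tensor([[0, 7, 2, 5]])
    assert interesting_logical_experts(old, new, 2, ep_rank=0) == {0: [7]}
    assert interesting_logical_experts(old, new, 2, ep_rank=1) == {0: [5]}

Each rank therefore only touches the experts whose placement on its own slots changed, which is what makes the per-name weight_filter cheap.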
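The memory-transfer patches immediately above hinge on one CUDA idiom: launch the host-to-device copy on a side stream, record an event, poll the event from the scheduler loop without blocking, and make the consumer stream wait on the side stream before using the result. A minimal sketch of that handshake, assuming a CUDA device is present; this is an illustration of the pattern, not the exact AsyncToCudaManager code.

    # Sketch of the side-stream copy + event polling used by AsyncToCudaManager.
    import torch

    assert torch.cuda.is_available()

    alt_stream = torch.cuda.Stream()
    src = torch.randn(4, 4).pin_memory()

    # Producer: order the side stream after current work, launch the async
    # H2D copy there, and record a completion event on the side stream.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        dst = src.to("cuda", non_blocking=True)
        finish_event = torch.cuda.Event()
        finish_event.record()

    # Consumer: poll without blocking; in the scheduler each check would be
    # one event_loop_step() tick rather than a busy loop.
    while not finish_event.query():
        pass
    # Make the main stream wait on the side stream before touching `dst`.
    torch.cuda.current_stream().wait_stream(alt_stream)
    print(dst.sum())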
From 5e5680d931007c54de2ea5f198c7008f253e184f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:56:48 +0800 Subject: [PATCH 0938/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index ac4c14197bd..ccb72f5f685 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -45,7 +45,7 @@ def get_outputs(self) -> List[NamedTensors]: class AsyncToCudaManager(TensorOperationManagerBase): def __init__(self): - self._inflight_tasks = [] + self._inflight_tasks: List[_AsyncToCudaTask] = [] self._alt_stream: Optional[torch.cuda.Stream] = None def enqueue(self, named_tensors: NamedTensors): @@ -68,7 +68,7 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: outputs = [] - while len(self._inflight_tasks) > 0: + while len(self._inflight_tasks) > 0 and self._inflight_tasks[0].finish_event.query(): TODO return outputs From 485b8dd0daae8adc563c609513325700980ecbd0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:57:10 +0800 Subject: [PATCH 0939/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index ccb72f5f685..e02a09a7a1c 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -69,7 +69,8 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: outputs = [] while len(self._inflight_tasks) > 0 and self._inflight_tasks[0].finish_event.query(): - TODO + task = self._inflight_tasks.pop(0) + outputs.append(task.output_named_tensors) return outputs def _auto_create_stream(self): From 9f40a96fd196a4c02c2a44127b9622ab50c80f9e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:57:56 +0800 Subject: [PATCH 0940/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index e02a09a7a1c..fb9fd7662b2 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -70,9 +70,13 @@ def get_outputs(self) -> List[NamedTensors]: outputs = [] while len(self._inflight_tasks) > 0 and self._inflight_tasks[0].finish_event.query(): task = self._inflight_tasks.pop(0) - outputs.append(task.output_named_tensors) + outputs.append(self._handle_one_output(task)) return outputs + def _handle_one_output(self, task: "_AsyncToCudaTask"): + torch.cuda.current_stream().wait_stream(self._alt_stream) + return task.output_named_tensors + def _auto_create_stream(self): if self._alt_stream is None: self._alt_stream = torch.cuda.Stream() From 9df33270767a6f621a0a4c014a647cdad4787038 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:58:44 +0800 Subject: [PATCH 0941/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index fb9fd7662b2..1d062546fa6 100644 --- 
a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from queue import SimpleQueue from typing import List, Tuple, Optional import torch @@ -36,6 +37,10 @@ def get_outputs(self) -> List[NamedTensors]: class AsyncPinMemoryManager(TensorOperationManagerBase): + def __init__(self): + self._input_queue = SimpleQueue() + self._output_queue = SimpleQueue() + def enqueue(self, named_tensors: NamedTensors): TODO From 2a3680af509381aaac3d59417da2c13fd1abd001 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 11:59:03 +0800 Subject: [PATCH 0942/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 1d062546fa6..f363a10c8cd 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -7,6 +7,7 @@ NamedTensors = List[Tuple[str, torch.Tensor]] +# For simplicity, classes here does not have tagging etc class TensorOperationManagerBase: def enqueue(self, named_tensors: NamedTensors): raise NotImplementedError From e053933cd1169fc0935672a626140301b522dfca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:00:05 +0800 Subject: [PATCH 0943/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index f363a10c8cd..07aa933359d 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -43,10 +43,17 @@ def __init__(self): self._output_queue = SimpleQueue() def enqueue(self, named_tensors: NamedTensors): - TODO + self._auto_create_background_thread() + self._input_queue.put_nowait(named_tensors) def get_outputs(self) -> List[NamedTensors]: - return TODO + outputs = [] + while True: + outputs.append(self._output_queue.get_nowait()) + return outputs + + def _auto_create_background_thread(self): + TODO class AsyncToCudaManager(TensorOperationManagerBase): From 1da663de884ce2df0f1fceeed4f076deb7126336 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:00:27 +0800 Subject: [PATCH 0944/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 07aa933359d..79712a1ea03 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,3 +1,4 @@ +import queue from dataclasses import dataclass from queue import SimpleQueue from typing import List, Tuple, Optional @@ -49,7 +50,10 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: outputs = [] while True: - outputs.append(self._output_queue.get_nowait()) + try: + outputs.append(self._output_queue.get_nowait()) + except queue.Empty: + break return outputs def _auto_create_background_thread(self): From 0dbfdd8e7d1c9009efa7d79f41708fc671a12976 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:01:21 +0800 Subject: [PATCH 0945/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 9 +++++++++ 1 file changed, 9 
insertions(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 79712a1ea03..084b87f3834 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,6 +1,7 @@ import queue from dataclasses import dataclass from queue import SimpleQueue +from threading import Thread from typing import List, Tuple, Optional import torch @@ -42,6 +43,7 @@ class AsyncPinMemoryManager(TensorOperationManagerBase): def __init__(self): self._input_queue = SimpleQueue() self._output_queue = SimpleQueue() + self._background_thread = None def enqueue(self, named_tensors: NamedTensors): self._auto_create_background_thread() @@ -57,6 +59,13 @@ def get_outputs(self) -> List[NamedTensors]: return outputs def _auto_create_background_thread(self): + if self._background_thread is not None: + return + + self._background_thread = Thread(target=self._background_thread_entrypoint) + self._background_thread.start() + + def _background_thread_entrypoint(self): TODO From 0b7069b752ec0c7d0eed4f2968042ea345065548 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:03:07 +0800 Subject: [PATCH 0946/1089] more --- .../sglang/srt/model_executor/memory_transfer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 084b87f3834..789ea9e854f 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -1,4 +1,6 @@ +import logging import queue +import traceback from dataclasses import dataclass from queue import SimpleQueue from threading import Thread @@ -8,6 +10,8 @@ NamedTensors = List[Tuple[str, torch.Tensor]] +logger = logging.getLogger(__name__) + # For simplicity, classes here does not have tagging etc class TensorOperationManagerBase: @@ -66,7 +70,15 @@ def _auto_create_background_thread(self): self._background_thread.start() def _background_thread_entrypoint(self): - TODO + try: + while True: + input_data = self._input_queue.get() + output_data = [(name, tensor.pin_memory()) for name, tensor in input_data] + self._output_queue.put(output_data) + except Exception as e: + logger.warning(f"AsyncPinMemoryManager background thread error {e}") + traceback.print_exc() + raise class AsyncToCudaManager(TensorOperationManagerBase): From 8358d022efe8b01065e8092b8946cbb68b947089 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:03:48 +0800 Subject: [PATCH 0947/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 789ea9e854f..777c606d4a1 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -81,6 +81,7 @@ def _background_thread_entrypoint(self): raise +# Can use cuMemCreate etc if we want to further remove a GPU->GPU copy class AsyncToCudaManager(TensorOperationManagerBase): def __init__(self): self._inflight_tasks: List[_AsyncToCudaTask] = [] From 28d3d82b3e69c384976958fa6af2a3f6cb52cff8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:04:26 +0800 Subject: [PATCH 0948/1089] more --- python/sglang/srt/poll_based_barrier.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 
python/sglang/srt/poll_based_barrier.py diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py new file mode 100644 index 00000000000..e55cbaf1eab --- /dev/null +++ b/python/sglang/srt/poll_based_barrier.py @@ -0,0 +1,2 @@ +class PollBasedBarrier: + TODO From beeb011120b2ac797916b1b9bb9f7682c0659bcc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:05:01 +0800 Subject: [PATCH 0949/1089] more --- python/sglang/srt/poll_based_barrier.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index e55cbaf1eab..865ec21213a 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -1,2 +1,6 @@ class PollBasedBarrier: - TODO + def local_arrive(self): + TODO + + def poll_global_arrive(self) -> bool: + TODO From e585d3f953a597cc2c6543763c62baac149879ad Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:05:56 +0800 Subject: [PATCH 0950/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 2 ++ python/sglang/srt/poll_based_barrier.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 6668d96b950..9adf99bf057 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -18,6 +18,7 @@ from sglang import ServerArgs from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType +from sglang.srt.poll_based_barrier import PollBasedBarrier class SchedulerInputBlocker: @@ -25,6 +26,7 @@ def __init__(self, server_args: ServerArgs, noop: bool): self._state = _State.UNBLOCKED self._pending_reqs = [] self._noop = noop + self._global_unblock_barrier = PollBasedBarrier(noop=noop) assert ( server_args.disable_overlap_schedule ), "SchedulerInputBlocker requires overlap scheduler to be disabled" diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index 865ec21213a..1ef758c279b 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -1,4 +1,7 @@ class PollBasedBarrier: + def __init__(self, noop: bool = False): + self._noop = noop + def local_arrive(self): TODO From 5ff279dfc1905678b48827c19bc5d77a0bcda558 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:08:04 +0800 Subject: [PATCH 0951/1089] more --- .../srt/managers/scheduler_input_blocker.py | 7 ------- python/sglang/srt/poll_based_barrier.py | 17 +++++++++++++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 9adf99bf057..54046de9479 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -74,13 +74,6 @@ def _execute_unblock_req(self): original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER ) - def _compute_global_unblock_barrier(self): - local_arrived = self._noop or (self._state == _State.GLOBAL_UNBLOCK_BARRIER) - global_arrived = torch.tensor(local_arrived).cuda() - # Can optimize if bottleneck - torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) - return global_arrived.cpu().item() - def _handle_arrive_unblock_barrier(self): self._change_state( original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED 
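Before the poll_based_barrier.py half of this patch, it may help to see the AsyncPinMemoryManager pieces from patches 0941-0947 above in one place. A hedged sketch with illustrative names, assuming the same SimpleQueue-plus-worker-thread design:

import queue
from queue import SimpleQueue
from threading import Thread

class BackgroundPinner:
    def __init__(self):
        self._input_queue = SimpleQueue()
        self._output_queue = SimpleQueue()
        self._thread = None

    def enqueue(self, named_tensors):
        if self._thread is None:
            # daemon=True only so this sketch exits cleanly; the real patches
            # start a plain Thread and log any worker exception.
            self._thread = Thread(target=self._run, daemon=True)
            self._thread.start()
        self._input_queue.put_nowait(named_tensors)

    def get_outputs(self):
        outputs = []
        while True:
            try:
                outputs.append(self._output_queue.get_nowait())
            except queue.Empty:
                return outputs

    def _run(self):
        while True:
            named_tensors = self._input_queue.get()  # blocks only the worker
            # pin_memory() copies each tensor into page-locked memory, which is
            # what makes later non_blocking H2D copies genuinely asynchronous.
            self._output_queue.put(
                [(name, t.pin_memory()) for name, t in named_tensors]
            )

SimpleQueue is unbounded and thread-safe, so the scheduler-side enqueue and drain never block; the try/except drain in patch 0944 exists because get_nowait raises queue.Empty rather than returning a sentinel.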
diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index 1ef758c279b..1aa6bbd1edd 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -1,9 +1,22 @@ +import torch + + class PollBasedBarrier: def __init__(self, noop: bool = False): self._noop = noop + self._local_arrived = False def local_arrive(self): - TODO + assert not self._local_arrived + self._local_arrived = True def poll_global_arrive(self) -> bool: - TODO + local_arrived = self._noop or self._local_arrived + global_arrived = torch.tensor(local_arrived).cuda() + # Can optimize if bottleneck + torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) + return global_arrived.cpu().item() + + def _change_state(self, original: "_State", target: "_State"): + assert self._state == original, f"{self._state=} {original=} {target=}" + self._state = target From e5c79370f06f6dc68f3c6cf507f77ad07d7371cf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:08:48 +0800 Subject: [PATCH 0952/1089] more --- python/sglang/srt/poll_based_barrier.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index 1aa6bbd1edd..f7e090e0533 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -10,7 +10,11 @@ def local_arrive(self): assert not self._local_arrived self._local_arrived = True - def poll_global_arrive(self) -> bool: + def poll_global_arrived(self) -> bool: + global_arrived = self._compute_global_arrived() + TODo + + def _compute_global_arrived(self) -> bool: local_arrived = self._noop or self._local_arrived global_arrived = torch.tensor(local_arrived).cuda() # Can optimize if bottleneck From 47134c29343c7e9dce0f6c77f5b18038f64e3438 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:09:24 +0800 Subject: [PATCH 0953/1089] more --- python/sglang/srt/poll_based_barrier.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index f7e090e0533..e06176781d9 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -12,8 +12,11 @@ def local_arrive(self): def poll_global_arrived(self) -> bool: global_arrived = self._compute_global_arrived() - TODo - + output = self._local_arrived and global_arrived + if output: + self._local_arrived = False + return output + def _compute_global_arrived(self) -> bool: local_arrived = self._noop or self._local_arrived global_arrived = torch.tensor(local_arrived).cuda() From 614236a4ce4b98c99571d30fb4396129521df108 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:10:11 +0800 Subject: [PATCH 0954/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 54046de9479..f4ef15f3eae 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -73,6 +73,7 @@ def _execute_unblock_req(self): self._change_state( original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER ) + self._global_unblock_barrier.local_arrive() def _handle_arrive_unblock_barrier(self): self._change_state( From 12cdd94b72eb149e30d72fe3741074ea42b17fe0 Mon Sep 17 00:00:00 2001 From: 
fzyzcjy Date: Thu, 17 Apr 2025 12:10:45 +0800 Subject: [PATCH 0955/1089] more --- python/sglang/srt/managers/scheduler_input_blocker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index f4ef15f3eae..30a26e3cc1e 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -39,7 +39,7 @@ def handle(self, recv_reqs: Optional[List[Any]]): for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) - global_arrived_unblock_barrier = self._compute_global_unblock_barrier() + global_arrived_unblock_barrier = self._global_unblock_barrier.poll_global_arrived() if ( self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived_unblock_barrier From 5586e15823168b83be22f6264da3c1be01860095 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:12:14 +0800 Subject: [PATCH 0956/1089] more --- python/sglang/srt/poll_based_barrier.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index e06176781d9..40d7fbe82c8 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -23,7 +23,3 @@ def _compute_global_arrived(self) -> bool: # Can optimize if bottleneck torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) return global_arrived.cpu().item() - - def _change_state(self, original: "_State", target: "_State"): - assert self._state == original, f"{self._state=} {original=} {target=}" - self._state = target From 28686bbb88bb30c0a812d7ec5700fcbe53964592 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:14:56 +0800 Subject: [PATCH 0957/1089] more --- .../model_executor/expert_location_updater.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 3d7153df3ca..8c046c00681 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -52,18 +52,13 @@ def act(self, recv_req: UpdateExpertLocationReqInput): get_global_expert_location_metadata().update(recv_req.expert_location_metadata) if self._model_runner.tp_rank == 0 and get_bool_env_var( - "SGLANG_LOG_EXPERT_LOCATION_METADATA" + "SGLANG_LOG_EXPERT_LOCATION_METADATA" ): logger.info( f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" ) - # We may be able to further reduce lock time by faster copying, pre-transfering, etc - self._model_runner.update_weights_from_disk( - model_path=self._model_runner.model_config.model_path, - load_format=self._model_runner.server_args.load_format, - param_categories=["moe"], - ) + self._model_weight_updater.act() logger.info("update_expert_location end") torch.distributed.barrier() @@ -71,15 +66,15 @@ def act(self, recv_req: UpdateExpertLocationReqInput): def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]): info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name) return ( - isinstance(info, ModelParamNameInfoMoe) - and (info.expert_id in interesting_logical_experts_of_layer[info.layer_id]) + isinstance(info, ModelParamNameInfoMoe) + and (info.expert_id in interesting_logical_experts_of_layer[info.layer_id]) 
) def _compute_interesting_logical_experts_of_layer( - old_expert_location_metadata: ExpertLocationMetadata, - new_expert_location_metadata: ExpertLocationMetadata, - ep_rank: int, + old_expert_location_metadata: ExpertLocationMetadata, + new_expert_location_metadata: ExpertLocationMetadata, + ep_rank: int, ) -> Dict[int, List[int]]: num_layers = old_expert_location_metadata.num_layers num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts From de7789625c92e6112b00ad79524e09fe63bc3252 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:15:16 +0800 Subject: [PATCH 0958/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 8c046c00681..5b6c7faf264 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -45,7 +45,6 @@ def event_loop_step(self): return TODO def act(self, recv_req: UpdateExpertLocationReqInput): - logger.info("update_expert_location start") torch.distributed.barrier() get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() @@ -60,7 +59,6 @@ def act(self, recv_req: UpdateExpertLocationReqInput): self._model_weight_updater.act() - logger.info("update_expert_location end") torch.distributed.barrier() def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]): From 41d2bf68aba2a0579311ea8626437598c486cd9e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:17:25 +0800 Subject: [PATCH 0959/1089] more --- .../sglang/srt/model_executor/expert_location_updater.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 5b6c7faf264..74a8c5a558e 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -8,6 +8,7 @@ from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater from sglang.srt.model_loader.weight_utils import ModelParamNameInfo, ModelParamNameInfoMoe +from sglang.srt.poll_based_barrier import PollBasedBarrier from sglang.srt.utils import get_bool_env_var if TYPE_CHECKING: @@ -26,6 +27,7 @@ def __init__(self, model_runner: "ModelRunner"): model=model_runner.model, device=model_runner.device, ) + self._prepare_end_barrier = PollBasedBarrier(noop=False) def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( @@ -41,8 +43,11 @@ def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): def event_loop_step(self): TODO_maybe_rename TODO_maybe_change_output - self._model_weight_updater.event_loop_step() - return TODO + prepare_done = self._model_weight_updater.event_loop_step() + if prepare_done: + self._prepare_end_barrier.local_arrive() + + TODO_outer_call_this def act(self, recv_req: UpdateExpertLocationReqInput): torch.distributed.barrier() From 084b195bc2c1d2e6904d3ae6d5d89694794052e1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:17:44 +0800 Subject: [PATCH 0960/1089] more --- 
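Patches 0948 through 0959 converge on a pattern that is easier to see in one piece: a non-blocking barrier that every rank polls from its event loop, driving a two-phase expert relocation in which weights are staged asynchronously and only the final swap stops the world. A hedged sketch, assuming an initialized torch.distributed process group; names roughly follow the patches but the classes are simplified:

import torch
import torch.distributed as dist

class PollBasedBarrier:
    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        # Each rank contributes 1 iff it has arrived; after a MIN all-reduce
        # the flag is 1 only once every rank has arrived.
        flag = torch.tensor(self._noop or self._local_arrived).cuda()
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        arrived = self._local_arrived and bool(flag.item())
        if arrived:
            self._local_arrived = False  # reset so the barrier is reusable
        return arrived

class ExpertRelocationDriver:
    def __init__(self, weight_updater):
        self._weight_updater = weight_updater  # stages weights asynchronously
        self._barrier = PollBasedBarrier()

    def event_loop_step(self):
        # Called once per scheduler iteration on every rank; never blocks.
        if self._weight_updater.poll_prepare_end():
            self._barrier.local_arrive()
        if self._barrier.poll_global_arrived():
            self._act()  # every rank has staged its weights

    def _act(self):
        dist.barrier()
        ...  # brief stop-the-world: swap expert-location metadata and weights
        dist.barrier()

Since poll_global_arrived() issues a collective, every rank must keep calling event_loop_step() on each iteration even when it has nothing staged; the later patches that wire the call into all of the scheduler loop variants (0968 onward) exist precisely for that reason.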
.../sglang/srt/model_executor/expert_location_updater.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 74a8c5a558e..ef3ded0633a 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -47,14 +47,17 @@ def event_loop_step(self): if prepare_done: self._prepare_end_barrier.local_arrive() + if self._prepare_end_barrier.poll_global_arrived(): + self.act() + TODO_outer_call_this - def act(self, recv_req: UpdateExpertLocationReqInput): + def act(self): torch.distributed.barrier() get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() - get_global_expert_location_metadata().update(recv_req.expert_location_metadata) + get_global_expert_location_metadata().update(expert_location_metadata) if self._model_runner.tp_rank == 0 and get_bool_env_var( "SGLANG_LOG_EXPERT_LOCATION_METADATA" ): From 80463a81fa23012edd954edf7c2f593279cdd4b9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:18:05 +0800 Subject: [PATCH 0961/1089] more --- .../sglang/srt/model_executor/expert_location_updater.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index ef3ded0633a..71533ed5105 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -41,17 +41,12 @@ def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): ) def event_loop_step(self): - TODO_maybe_rename - TODO_maybe_change_output - prepare_done = self._model_weight_updater.event_loop_step() - if prepare_done: + if self._model_weight_updater.event_loop_step(): self._prepare_end_barrier.local_arrive() if self._prepare_end_barrier.poll_global_arrived(): self.act() - TODO_outer_call_this - def act(self): torch.distributed.barrier() From 2b39cfe6fc14c828fdff89146fde74f5226650ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:18:31 +0800 Subject: [PATCH 0962/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 71533ed5105..ef48d93a9df 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -45,9 +45,9 @@ def event_loop_step(self): self._prepare_end_barrier.local_arrive() if self._prepare_end_barrier.poll_global_arrived(): - self.act() + self._act() - def act(self): + def _act(self): torch.distributed.barrier() get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() From f92ab32166eac824a31ce9cc1f8441a717e29ce6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:19:20 +0800 Subject: [PATCH 0963/1089] more --- .../srt/model_executor/expert_location_updater.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index ef48d93a9df..244c1cfde78 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ 
b/python/sglang/srt/model_executor/expert_location_updater.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional import torch from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder @@ -28,8 +28,12 @@ def __init__(self, model_runner: "ModelRunner"): device=model_runner.device, ) self._prepare_end_barrier = PollBasedBarrier(noop=False) + self._ongoing_req: Optional[UpdateExpertLocationReqInput] = None + + def start_prepare(self, req: UpdateExpertLocationReqInput): + assert self._ongoing_req is None + self._ongoing_req = req - def start_prepare(self, expert_location_metadata: ExpertLocationMetadata): interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( old_expert_location_metadata=get_global_expert_location_metadata(), new_expert_location_metadata=expert_location_metadata, @@ -52,7 +56,7 @@ def _act(self): get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() - get_global_expert_location_metadata().update(expert_location_metadata) + get_global_expert_location_metadata().update(self._ongoing_req.expert_location_metadata) if self._model_runner.tp_rank == 0 and get_bool_env_var( "SGLANG_LOG_EXPERT_LOCATION_METADATA" ): @@ -64,6 +68,9 @@ def _act(self): torch.distributed.barrier() + assert self._ongoing_req is not None + self._ongoing_req = None + def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]): info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name) return ( From 0730e8f3b305d9e0a1d2b385a4c39332ca149aeb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:19:48 +0800 Subject: [PATCH 0964/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 244c1cfde78..a2bf8f6e533 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -30,13 +30,13 @@ def __init__(self, model_runner: "ModelRunner"): self._prepare_end_barrier = PollBasedBarrier(noop=False) self._ongoing_req: Optional[UpdateExpertLocationReqInput] = None - def start_prepare(self, req: UpdateExpertLocationReqInput): + def start(self, req: UpdateExpertLocationReqInput): assert self._ongoing_req is None self._ongoing_req = req interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( old_expert_location_metadata=get_global_expert_location_metadata(), - new_expert_location_metadata=expert_location_metadata, + new_expert_location_metadata=req.expert_location_metadata, ep_rank=self._model_runner.tp_rank, ) From a0fbeebf371a98a679bc4df339c42c4b23630cf6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 12:20:05 +0800 Subject: [PATCH 0965/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index a2bf8f6e533..3bb4e9c9601 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -51,6 +51,8 @@ def event_loop_step(self): if self._prepare_end_barrier.poll_global_arrived(): 
self._act() + TODO_return_act_end + def _act(self): torch.distributed.barrier() From 10ac4fb7e54b7c4c628c66238a61c4c0a17135c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:09:10 +0800 Subject: [PATCH 0966/1089] more --- python/sglang/srt/server_args.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dcbb8d050a4..421ae102735 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -336,11 +336,8 @@ def __post_init__(self): ) if self.enable_eplb: - self.enable_scheduler_input_blocker = True self.enable_expert_distribution_recorder = True - logger.info( - f"EPLB is enabled. The enable_scheduler_input_blocker and enable_expert_distribution_recorder are automatically enabled." - ) + logger.info(f"EPLB is enabled. The enable_expert_distribution_recorder is automatically enabled.") if self.enable_eplb or (self.init_expert_location is not None): self.ep_dispatch_algorithm = "static" logger.info( From 6c60e040c84546f031b0ae1843e53d5c32a493ab Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:09:50 +0800 Subject: [PATCH 0967/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bd47edfb338..cb448a6b4f7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -723,15 +723,9 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): async def _update_expert_location_raw(self, expert_location_metadata: ExpertLocationMetadata): self.expert_location_metadata = None - - TODO_prepare - - TODO_rename_to_act - self._send_block_request(BlockReqType.BLOCK) - await self.update_expert_location_communicator.call_send(TODO) - self._send_block_request(BlockReqType.UNBLOCK) - await self.update_expert_location_communicator.call_await() - + await self.update_expert_location_communicator(UpdateExpertLocationReqInput( + expert_location_metadata=expert_location_metadata, + )) self.expert_location_metadata = expert_location_metadata async def update_weights_from_disk( From 44bb9f4f99fd26f8b4cf2a4f78837fab3b37a048 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:12:07 +0800 Subject: [PATCH 0968/1089] more --- python/sglang/srt/managers/scheduler.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 55563dea09c..c1e90d6f653 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -605,6 +605,7 @@ def event_loop_normal(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -627,6 +628,7 @@ def event_loop_overlap(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -666,6 +668,7 @@ def event_loop_normal_disagg_prefill(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() self.waiting_queue.extend( self.disagg_prefill_pending_queue.pop_bootstrapped() ) @@ -696,6 +699,7 @@ def 
event_loop_normal_disagg_decode(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() # polling and allocating kv cache self.process_decode_queue() batch = self.get_next_disagg_decode_batch_to_run() @@ -801,6 +805,9 @@ def process_input_requests(self, recv_reqs: List): else: self.send_to_tokenizer.send_pyobj(output) + def model_runner_event_loop_step(self): + self.tp_worker.worker.model_runner.event_loop_step() + def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, From 9b0895562fb7ce54657273cc2a033583fccd1e33 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:12:26 +0800 Subject: [PATCH 0969/1089] more --- python/sglang/srt/managers/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c1e90d6f653..7ab23c878ff 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -806,7 +806,9 @@ def process_input_requests(self, recv_reqs: List): self.send_to_tokenizer.send_pyobj(output) def model_runner_event_loop_step(self): - self.tp_worker.worker.model_runner.event_loop_step() + outputs = self.tp_worker.worker.model_runner.event_loop_step() + for output in outputs: + self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( self, From d782002089ab6401e798b3ae9dbc77a82e8f8527 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:13:09 +0800 Subject: [PATCH 0970/1089] more --- python/sglang/srt/model_executor/model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0d2d891d5a5..ebcf5f8a091 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,7 +20,7 @@ import os import time from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union, Any import torch import torch.distributed as dist @@ -496,7 +496,10 @@ def load_model(self): ) from None def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): - self._expert_location_updater.act(recv_req) + self._expert_location_updater.start(recv_req) + + def event_loop_step(self) -> List[Any]: + TODO def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] From f67e583f3f7e74adeea39b2fc402ce4505c9b40c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:13:35 +0800 Subject: [PATCH 0971/1089] more --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ebcf5f8a091..20b18aee5ac 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -497,9 +497,9 @@ def load_model(self): def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): self._expert_location_updater.start(recv_req) - + def event_loop_step(self) -> List[Any]: - TODO + return self._expert_location_updater.event_loop_step() def update_weights_from_disk( self, model_path: str, load_format: str, param_categories: Optional[List[str]] From 63f3ec4237d0cf241b478a2469f22402abb6de6b Mon Sep 17 00:00:00 2001 
From: fzyzcjy Date: Thu, 17 Apr 2025 14:13:49 +0800 Subject: [PATCH 0972/1089] more --- .../srt/model_executor/expert_location_updater.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 3bb4e9c9601..1396f1896c0 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -1,10 +1,10 @@ import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Any import torch from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput +from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput, UpdateExpertLocationReqOutput from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater from sglang.srt.model_loader.weight_utils import ModelParamNameInfo, ModelParamNameInfoMoe @@ -44,14 +44,16 @@ def start(self, req: UpdateExpertLocationReqInput): weight_filter=lambda name: self._weight_filter(name, interesting_logical_experts_of_layer), ) - def event_loop_step(self): + def event_loop_step(self) -> List[UpdateExpertLocationReqOutput]: + outputs = [] + if self._model_weight_updater.event_loop_step(): self._prepare_end_barrier.local_arrive() if self._prepare_end_barrier.poll_global_arrived(): - self._act() + outputs.append(self._act()) - TODO_return_act_end + return outputs def _act(self): torch.distributed.barrier() @@ -73,6 +75,8 @@ def _act(self): assert self._ongoing_req is not None self._ongoing_req = None + return UpdateExpertLocationReqOutput() + def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]): info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name) return ( From 4a819366c119a9f562c4995eb6928a2da3dc01a5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:14:30 +0800 Subject: [PATCH 0973/1089] more --- .../model_executor/expert_location_updater.py | 4 ++-- .../srt/model_executor/model_weight_updater.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 1396f1896c0..3d59c1bcfff 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Any +from typing import TYPE_CHECKING, Dict, List, Optional import torch from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder @@ -47,7 +47,7 @@ def start(self, req: UpdateExpertLocationReqInput): def event_loop_step(self) -> List[UpdateExpertLocationReqOutput]: outputs = [] - if self._model_weight_updater.event_loop_step(): + if self._model_weight_updater.poll_prepare_end(): self._prepare_end_barrier.local_arrive() if self._prepare_end_barrier.poll_global_arrived(): diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 3999eda074c..56326a0c3ea 100644 --- 
a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -12,12 +12,12 @@ class ModelWeightUpdater: def __init__( - self, - init_pin_memory: bool, - load_format: str, - model_config: ModelConfig, - model, - device, + self, + init_pin_memory: bool, + load_format: str, + model_config: ModelConfig, + model, + device, ): self._model_config = model_config self._model = model @@ -39,10 +39,7 @@ def start_prepare(self, weight_filter: Callable[[str], bool]): self._state = _StateAwaitMemoryTransfer() - def event_loop_step(self): - TODO_maybe_rename - TODO_maybe_change_output - + def poll_prepare_end(self): memory_transfer_outputs = self._memory_transfer_manager.get_outputs() assert len(memory_transfer_outputs) in {0, 1} if len(memory_transfer_outputs) == 0: From 8e036e1db8df9d1c161bcec98624251632f359df Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:15:24 +0800 Subject: [PATCH 0974/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 3d59c1bcfff..b12c067d7ff 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -22,7 +22,7 @@ def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner self._model_weight_updater = ModelWeightUpdater( init_pin_memory=TODO, - load_format=TODO, + load_format=model_runner.server_args.load_format, model_config=model_runner.model_config, model=model_runner.model, device=model_runner.device, From 9b501cef4aa10bd9488ec45ba96e04ea11408890 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:18:29 +0800 Subject: [PATCH 0975/1089] more --- .../sglang/srt/model_executor/expert_location_updater.py | 2 +- python/sglang/srt/server_args.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index b12c067d7ff..23574f8b8f4 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -21,7 +21,7 @@ class ExpertLocationUpdater: def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner self._model_weight_updater = ModelWeightUpdater( - init_pin_memory=TODO, + init_pin_memory=model_runner.server_args.expert_location_updater_mode == "pin_memory", load_format=model_runner.server_args.load_format, model_config=model_runner.model_config, model=model_runner.model, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 421ae102735..e0f91ab3b37 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -165,6 +165,7 @@ class ServerArgs: ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "random"]] = None init_expert_location: Optional[str] = None + expert_location_updater_mode: Literal["pin_memory", "pageable_memory"] = "pin_memory" enable_eplb: bool = False eplb_storage_dir: str = "/tmp/eplb_storage" eplb_rebalance_period: Optional[int] = None @@ -1147,6 +1148,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.init_expert_location, help="Initial location of EP experts.", ) + parser.add_argument( + 
"--expert-location-updater-mode", + type=str, + default=ServerArgs.expert_location_updater_mode, + help="Mode of ExpertLocationUpdater, can be `pin_memory` (put weights in pinned memory at startup, thus faster but takes more host memory) or `pageable_memory` (put weights on pageable memory, thus slower but takes less host memory)", + ) parser.add_argument( "--enable-eplb", action="store_true", From 0ab750b536dd6660453054bda48e499f6c300d19 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:19:34 +0800 Subject: [PATCH 0976/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cb448a6b4f7..54e9aba260c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -713,10 +713,11 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): layer_id_lens = list(range(0, num_layers, 10)) + [num_layers] - for layer_id_len in layer_id_lens: + for layer_id_end in layer_id_lens: + logger.info(f"update_expert_location handling 0~{layer_id_end} layers") partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, - layer_id_len=layer_id_len) + layer_id_len=layer_id_end) await self._update_expert_location_raw( expert_location_metadata=partial_expert_location_metadata, ) From 9a49ec61042269e55446f835434ee3f8ad246519 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:21:12 +0800 Subject: [PATCH 0977/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 3ea0a7bd1b9..eaf2adf40fc 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,8 +23,8 @@ class TestEPLB(CustomTestCase): - def test_eplb_e2e(self): - print("Action: test_eplb_e2e") + def test_eplb_start_rebalance_restart(self): + print("Action: test_eplb_start_rebalance_restart") with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, From b48a19164504d8c62a52ccbf6e513262fa64bca9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:21:31 +0800 Subject: [PATCH 0978/1089] more --- test/srt/test_eplb.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index eaf2adf40fc..f791b840731 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,6 +23,9 @@ class TestEPLB(CustomTestCase): + def test_eplb_many_rebalance(self): + TODO + def test_eplb_start_rebalance_restart(self): print("Action: test_eplb_start_rebalance_restart") with tempfile.TemporaryDirectory() as tmpdir: From d74299c95a54765370d67461ff9536c3d721227a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:22:16 +0800 Subject: [PATCH 0979/1089] more --- test/srt/test_eplb.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f791b840731..ba1d2e1fbf2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -23,8 +23,34 @@ class TestEPLB(CustomTestCase): - def test_eplb_many_rebalance(self): - TODO + def test_eplb_many_rebalances(self): + print("Action: test_eplb_many_rebalances") + with tempfile.TemporaryDirectory() as tmpdir: + engine_kwargs = dict( + 
model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, + enable_eplb=True, + eplb_storage_dir=tmpdir, + ep_num_redundant_experts=4, + enable_dp_attention=True, + enable_deepep_moe=True, + deepep_mode="normal", + disable_cuda_graph=True, + enable_scheduler_input_blocker=True, + disable_overlap_schedule=True, + tp_size=2, + dp_size=2, + log_level="info", + ) + + print(f"Action: start engine") + engine = sgl.Engine(**engine_kwargs) + + TODO + + print(f"Action: shutdown engine") + engine.shutdown() + del engine def test_eplb_start_rebalance_restart(self): print("Action: test_eplb_start_rebalance_restart") @@ -40,7 +66,7 @@ def test_eplb_start_rebalance_restart(self): deepep_mode="normal", disable_cuda_graph=True, enable_scheduler_input_blocker=True, - disable_overlap_schedule=True, # TODO + disable_overlap_schedule=True, tp_size=2, dp_size=2, log_level="info", @@ -84,7 +110,7 @@ def test_eplb_init_expert_location_and_save_expert_distribution(self): enable_deepep_moe=True, deepep_mode="normal", disable_cuda_graph=True, - disable_overlap_schedule=True, # TODO + disable_overlap_schedule=True, tp_size=2, dp_size=2, log_level="info", From 121bc451fa518fc8a96b6eae39dce30bdb6a0592 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:22:47 +0800 Subject: [PATCH 0980/1089] more --- test/srt/test_eplb.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index ba1d2e1fbf2..82de402b08f 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,3 +1,4 @@ +import asyncio import json import tempfile import unittest @@ -25,6 +26,10 @@ class TestEPLB(CustomTestCase): def test_eplb_many_rebalances(self): print("Action: test_eplb_many_rebalances") + + async def _main_async(): + TODO + with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -41,12 +46,14 @@ def test_eplb_many_rebalances(self): tp_size=2, dp_size=2, log_level="info", + disable_radix_cache=True, ) print(f"Action: start engine") engine = sgl.Engine(**engine_kwargs) - TODO + loop = asyncio.get_event_loop() + loop.run_until_complete(_main_async) print(f"Action: shutdown engine") engine.shutdown() From 9bedaa2c9015a6162151dbeba77c3aa509aa839d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:23:17 +0800 Subject: [PATCH 0981/1089] more --- test/srt/test_eplb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 82de402b08f..a13e13516e5 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -28,7 +28,10 @@ def test_eplb_many_rebalances(self): print("Action: test_eplb_many_rebalances") async def _main_async(): - TODO + await asyncio.gather( + asyncio.create_task(TODO), + asyncio.create_task(TODO), + ) with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( From 4ebb695a880113babc84f1955991bd9c9f7624dd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:23:42 +0800 Subject: [PATCH 0982/1089] more --- test/srt/test_eplb.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index a13e13516e5..eeecf54800a 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -29,10 +29,16 @@ def test_eplb_many_rebalances(self): async def _main_async(): await asyncio.gather( - asyncio.create_task(TODO), - asyncio.create_task(TODO), + asyncio.create_task(_task_generate()), + 
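One remark on the chunked rollout visible in patch 0976 above: rather than swapping the whole expert placement at once, the tokenizer manager walks the layers in chunks of ten and applies a partially-updated copy of the metadata each round, so every blocking swap stays short. A minimal sketch of just that loop, with a hypothetical apply_fn standing in for the real block/unblock plumbing:

import copy

def rollout_in_chunks(old_meta, new_meta, num_layers, apply_fn, step=10):
    # e.g. num_layers=27 -> ends are [0, 10, 20, 27]; each round applies the
    # new placement for layers [0, end).
    for layer_id_end in list(range(0, num_layers, step)) + [num_layers]:
        partial = copy.deepcopy(old_meta)
        partial.update(new_meta, layer_id_start=0, layer_id_len=layer_id_end)
        apply_fn(partial)

Note that the range starts at 0, so the first round appears to apply an unchanged copy of the old placement; harmless, but it looks like one redundant update cycle.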
asyncio.create_task(_task_rebalance()), ) + async def _task_generate(): + TODO + + async def _task_rebalance(): + TODO + with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, From f45df0ada0f4bd72e348d38aeebb6f3ada6f4ba5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:24:12 +0800 Subject: [PATCH 0983/1089] more --- test/srt/test_eplb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index eeecf54800a..1509cac3cdc 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -27,6 +27,8 @@ class TestEPLB(CustomTestCase): def test_eplb_many_rebalances(self): print("Action: test_eplb_many_rebalances") + num_rebalance = 20 + async def _main_async(): await asyncio.gather( asyncio.create_task(_task_generate()), @@ -37,7 +39,8 @@ async def _task_generate(): TODO async def _task_rebalance(): - TODO + for i in range(num_rebalance): + await engine.tokenizer_manager.eplb_rebalance() with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( From 8e027e64999244ab0618bb76c693d3ea2bb85f05 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:25:37 +0800 Subject: [PATCH 0984/1089] more --- test/srt/test_eplb.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 1509cac3cdc..a74f5090ee7 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -36,7 +36,13 @@ async def _main_async(): ) async def _task_generate(): - TODO + tasks = [] + async for request in _get_requests(): + tasks.append(asyncio.create_task(engine.async_generate( + prompt=TODO, + sampling_params=dict(temperature=0), + ))) + TODO_test_result async def _task_rebalance(): for i in range(num_rebalance): From cd874a72fa3d09013a8ccd86a56cfa9b79768bc2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:31:53 +0800 Subject: [PATCH 0985/1089] more --- test/srt/test_eplb.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index a74f5090ee7..7c8efd8fe6d 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -3,7 +3,9 @@ import tempfile import unittest from pathlib import Path +from typing import List +import numpy as np import sglang as sgl import torch from python.sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map @@ -281,5 +283,12 @@ def _compute_trivial_expert_locations(ep_num_redundant_experts: int): ) +async def _get_request(input_requests: List[str], request_rate: float): + for request in input_requests: + yield request + interval = np.random.exponential(1.0 / request_rate) + await asyncio.sleep(interval) + + if __name__ == "__main__": unittest.main() From 9b5bdfe4bcc4c60ecb3a00c9dd944fb8a3559e17 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:33:25 +0800 Subject: [PATCH 0986/1089] more --- test/srt/test_eplb.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 7c8efd8fe6d..fae6629a2bc 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -30,6 +30,10 @@ def test_eplb_many_rebalances(self): print("Action: test_eplb_many_rebalances") num_rebalance = 20 + request_rate = 20 + requests = [ + dict(prompt=TODO), + ] async def _main_async(): await asyncio.gather( @@ -39,10 +43,10 @@ async def _main_async(): async def _task_generate(): tasks = [] - async for request in 
_get_requests(): + async for request in _yield_with_poisson_process(requests, action_rate=request_rate): tasks.append(asyncio.create_task(engine.async_generate( - prompt=TODO, - sampling_params=dict(temperature=0), + prompt=request["prompt"], + sampling_params=dict(temperature=0, max_new_tokens=8), ))) TODO_test_result @@ -283,10 +287,10 @@ def _compute_trivial_expert_locations(ep_num_redundant_experts: int): ) -async def _get_request(input_requests: List[str], request_rate: float): - for request in input_requests: - yield request - interval = np.random.exponential(1.0 / request_rate) +async def _yield_with_poisson_process(items: List, action_rate: float): + for item in items: + yield item + interval = np.random.exponential(1.0 / action_rate) await asyncio.sleep(interval) From 9b506b091981c58dca9fdb303a1362b1ff8faa57 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:34:19 +0800 Subject: [PATCH 0987/1089] more --- test/srt/test_eplb.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index fae6629a2bc..b776e4d1d3c 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -31,8 +31,11 @@ def test_eplb_many_rebalances(self): num_rebalance = 20 request_rate = 20 - requests = [ - dict(prompt=TODO), + prompts = [ + TODO, + ] + expect_outputs = [ + TODO, ] async def _main_async(): @@ -43,11 +46,13 @@ async def _main_async(): async def _task_generate(): tasks = [] - async for request in _yield_with_poisson_process(requests, action_rate=request_rate): + async for prompt in _yield_with_poisson_process(prompts, action_rate=request_rate): tasks.append(asyncio.create_task(engine.async_generate( - prompt=request["prompt"], + prompt=prompt, sampling_params=dict(temperature=0, max_new_tokens=8), ))) + + outputs = await asyncio.gather(*tasks) TODO_test_result async def _task_rebalance(): From 003815c14c021d5c543cc0d090508a7b492081d3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:35:31 +0800 Subject: [PATCH 0988/1089] more --- test/srt/test_eplb.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index b776e4d1d3c..ec42c7f1438 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -1,5 +1,6 @@ import asyncio import json +import random import tempfile import unittest from pathlib import Path @@ -31,11 +32,9 @@ def test_eplb_many_rebalances(self): num_rebalance = 20 request_rate = 20 - prompts = [ - TODO, - ] - expect_outputs = [ - TODO, + content_duplicate_num = 20 + contents_raw = [ + dict(prompt=TODO, expect_output=TODO), ] async def _main_async(): @@ -45,14 +44,17 @@ async def _main_async(): ) async def _task_generate(): + contents_duplicated = contents_raw * content_duplicate_num + random.shuffle(contents_duplicated) + tasks = [] - async for prompt in _yield_with_poisson_process(prompts, action_rate=request_rate): + async for content in _yield_with_poisson_process(contents_duplicated, action_rate=request_rate): tasks.append(asyncio.create_task(engine.async_generate( - prompt=prompt, + prompt=content["prompt"], sampling_params=dict(temperature=0, max_new_tokens=8), ))) - outputs = await asyncio.gather(*tasks) + actual_outputs = await asyncio.gather(*tasks) TODO_test_result async def _task_rebalance(): From 3768d14596444652a1772ce7dc82f37c5adaa2d4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:35:48 +0800 Subject: [PATCH 0989/1089] more --- test/srt/test_eplb.py | 2 ++ 1 file 
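The load-generation side of this test (patches 0985-0988 above) is an open-loop driver: inter-arrival gaps drawn from an exponential distribution yield a Poisson arrival process at the configured request rate, while a sibling coroutine fires rebalances concurrently. A compact standalone version, with a toy handler standing in for engine.async_generate:

import asyncio
import random

import numpy as np

async def yield_with_poisson_process(items, action_rate: float):
    for item in items:
        yield item
        # Exponential gaps with mean 1/rate give a Poisson process at `rate`/s.
        await asyncio.sleep(np.random.exponential(1.0 / action_rate))

async def demo():
    async def handle(prompt):  # stand-in for engine.async_generate
        await asyncio.sleep(random.uniform(0.01, 0.05))
        return {"text": prompt.upper()}

    async def generate():
        tasks = []
        async for p in yield_with_poisson_process(["a", "b"] * 10, action_rate=20):
            tasks.append(asyncio.create_task(handle(p)))
        return await asyncio.gather(*tasks)  # results in submission order

    async def rebalance():  # stand-in for the eplb_rebalance loop
        for _ in range(3):
            await asyncio.sleep(0.2)

    outputs, _ = await asyncio.gather(generate(), rebalance())
    assert [o["text"] for o in outputs] == [p.upper() for p in ["a", "b"] * 10]

asyncio.run(demo())

Because each task is created the moment its arrival fires, requests overlap the in-flight rebalances, which is exactly the interleaving the test wants to exercise.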
changed, 2 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index ec42c7f1438..0de05debcb9 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -55,6 +55,8 @@ async def _task_generate(): ))) actual_outputs = await asyncio.gather(*tasks) + print(f"{actual_outputs=}") + actual_output_texts = [x["text"] for x in actual_outputs] TODO_test_result async def _task_rebalance(): From d011216946eddff8408791daecc46aa94574a804 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:36:20 +0800 Subject: [PATCH 0990/1089] more --- test/srt/test_eplb.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0de05debcb9..970646158e7 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -55,9 +55,10 @@ async def _task_generate(): ))) actual_outputs = await asyncio.gather(*tasks) - print(f"{actual_outputs=}") actual_output_texts = [x["text"] for x in actual_outputs] - TODO_test_result + expect_output_texts = [x["expect_output"] for x in contents_duplicated] + print(f"{actual_output_texts=} {expect_output_texts=}") + self.assertEqual(actual_output_texts, expect_output_texts) async def _task_rebalance(): for i in range(num_rebalance): From f4c5303c82c64677ec2e1b4127e85a3b8e4c5c6c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:37:08 +0800 Subject: [PATCH 0991/1089] more --- test/srt/test_eplb.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 970646158e7..bde7d040e8b 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -34,7 +34,18 @@ def test_eplb_many_rebalances(self): request_rate = 20 content_duplicate_num = 20 contents_raw = [ - dict(prompt=TODO, expect_output=TODO), + dict( + prompt="1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", + expect_output="TODO", + ), + dict( + prompt="2*1=2, 2*2=4, 2*3=6, 2*4=", + expect_output="TODO", + ), + dict( + prompt="One plus one is two, one plus two is three, one plus three is", + expect_output="TODO", + ), ] async def _main_async(): From 007f8b172494c616e49bbe5b76e63e4cfaa001e4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:38:08 +0800 Subject: [PATCH 0992/1089] more --- test/srt/test_eplb.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index bde7d040e8b..ed78c555fea 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -42,6 +42,14 @@ def test_eplb_many_rebalances(self): prompt="2*1=2, 2*2=4, 2*3=6, 2*4=", expect_output="TODO", ), + dict( + prompt="10*1=10, 10*2=20, 10*3=30, 10*4=40, 10*5=50, 10*6=", + expect_output="TODO", + ), + dict( + prompt="2/2=1, 4/2=2, 6/2=3, 8/2=", + expect_output="TODO", + ), dict( prompt="One plus one is two, one plus two is three, one plus three is", expect_output="TODO", From af6f5130c5eb21016e19974f15b1322b765787d1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:38:31 +0800 Subject: [PATCH 0993/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index ed78c555fea..f0541c092d2 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -70,7 +70,7 @@ async def _task_generate(): async for content in _yield_with_poisson_process(contents_duplicated, action_rate=request_rate): tasks.append(asyncio.create_task(engine.async_generate( prompt=content["prompt"], - sampling_params=dict(temperature=0, max_new_tokens=8), 
+ sampling_params=dict(temperature=0, max_new_tokens=4), ))) actual_outputs = await asyncio.gather(*tasks) From eadf4ae6d20da40f85c2ebda06306396d7cf8b80 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:39:45 +0800 Subject: [PATCH 0994/1089] more --- test/srt/test_eplb.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f0541c092d2..4a22bda5d5e 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -28,6 +28,12 @@ class TestEPLB(CustomTestCase): def test_eplb_many_rebalances(self): + self._test_eplb_many_rebalances_core() + + def test_eplb_many_rebalances_baseline(self): + self._test_eplb_many_rebalances_core(enable_eplb=False) + + def _test_eplb_many_rebalances_core(self, enable_eplb: bool = True): print("Action: test_eplb_many_rebalances") num_rebalance = 20 @@ -80,6 +86,10 @@ async def _task_generate(): self.assertEqual(actual_output_texts, expect_output_texts) async def _task_rebalance(): + if not enable_eplb: + print("task_rebalance skip since not enable eplb") + return + for i in range(num_rebalance): await engine.tokenizer_manager.eplb_rebalance() @@ -87,7 +97,7 @@ async def _task_rebalance(): engine_kwargs = dict( model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, trust_remote_code=True, - enable_eplb=True, + enable_eplb=enable_eplb, eplb_storage_dir=tmpdir, ep_num_redundant_experts=4, enable_dp_attention=True, From 238e32661edc67d71809a8b27c099b558090820f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 14:41:06 +0800 Subject: [PATCH 0995/1089] more --- test/srt/test_eplb.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 4a22bda5d5e..20e6f50a3af 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -122,7 +122,13 @@ async def _task_rebalance(): engine.shutdown() del engine - def test_eplb_start_rebalance_restart(self): + def test_eplb_start_rebalance_restart_mode_pin_memory(self): + self._test_eplb_start_rebalance_restart_core(expert_location_updater_mode="pin_memory") + + def test_eplb_start_rebalance_restart_mode_pageable_memory(self): + self._test_eplb_start_rebalance_restart_core(expert_location_updater_mode="pageable_memory") + + def _test_eplb_start_rebalance_restart_core(self, expert_location_updater_mode: str): print("Action: test_eplb_start_rebalance_restart") with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( @@ -140,6 +146,7 @@ def test_eplb_start_rebalance_restart(self): tp_size=2, dp_size=2, log_level="info", + expert_location_updater_mode=expert_location_updater_mode, ) print(f"Action: start engine") From 1d858da40627367974336abc2ccca101d79a1e2b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:16:12 +0800 Subject: [PATCH 0996/1089] more --- python/sglang/srt/managers/expert_location.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index b5d03f11d69..4c0389298ba 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -275,11 +275,12 @@ def compute_logical_to_rank_dispatch_physical_map( num_local_physical_experts = num_physical_experts // num_gpus num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + dtype = logical_to_all_physical_map.dtype logical_to_rank_dispatch_physical_map = torch.full( size=(num_gpus, num_layers, num_logical_experts), 
fill_value=-1, - dtype=logical_to_all_physical_map.dtype, + dtype=dtype, ) for layer_id in range(num_layers): @@ -300,7 +301,7 @@ def compute_logical_to_rank_dispatch_physical_map( num_remain = torch.sum(partial_map == -1).item() partial_map[partial_map == -1] = torch.tensor( - _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r)) + _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), dtype=dtype) assert torch.all(logical_to_rank_dispatch_physical_map != -1) return logical_to_rank_dispatch_physical_map From d25429ab7689166163870903c26f6dc3d7cb379c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:18:15 +0800 Subject: [PATCH 0997/1089] more --- python/sglang/srt/managers/expert_location.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 4c0389298ba..7e6a26cc9ab 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -288,7 +288,7 @@ def compute_logical_to_rank_dispatch_physical_map( candidate_physical_expert_ids = ExpertLocationMetadata.logical_to_all_physical_raw( logical_to_all_physical_map, layer_id, logical_expert_id ) - partial_map = logical_to_all_physical_map[:, layer_id, logical_expert_id] + output_partial = logical_to_rank_dispatch_physical_map[:, layer_id, logical_expert_id] for gpu_id in range(num_gpus): same_gpu_physical_expert_ids = [ @@ -297,10 +297,10 @@ def compute_logical_to_rank_dispatch_physical_map( if _compute_gpu_id_of_physical_expert(physical_expert_id, num_local_physical_experts) == gpu_id ] if len(same_gpu_physical_expert_ids) > 0: - partial_map[gpu_id] = same_gpu_physical_expert_ids[0] + output_partial[gpu_id] = same_gpu_physical_expert_ids[0] - num_remain = torch.sum(partial_map == -1).item() - partial_map[partial_map == -1] = torch.tensor( + num_remain = torch.sum(output_partial == -1).item() + output_partial[output_partial == -1] = torch.tensor( _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), dtype=dtype) assert torch.all(logical_to_rank_dispatch_physical_map != -1) From 068002678b6d7b0044d7cd64d7ee2e5f2874643a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:18:49 +0800 Subject: [PATCH 0998/1089] more --- python/sglang/srt/managers/expert_location.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 7e6a26cc9ab..23bdc4b6ae4 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -265,12 +265,12 @@ def _pad_nested_array(arr, pad_value): return padded +# This is rarely called, so we use for loops for maximum clarity def compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, num_physical_experts: int, ): - # This is rarely called, so we use for loops for maximum clarity r = random.Random() num_local_physical_experts = num_physical_experts // num_gpus From bad67f48c70ec2a1abfdb8fc51e5e8edcd251ef9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:19:56 +0800 Subject: [PATCH 0999/1089] more --- python/sglang/srt/model_executor/model_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 20b18aee5ac..1d36d7717d4 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -33,7 +33,7 @@ initialize_model_parallel, set_custom_all_reduce, ) -from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state, get_world_group from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, @@ -194,7 +194,11 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() - [expert_location_metadata] = broadcast_pyobj(data=[expert_location_metadata], rank=torch.distributed.get_rank()) + [expert_location_metadata] = broadcast_pyobj( + data=[expert_location_metadata], + rank=torch.distributed.get_rank(), + dist_group=get_world_group().cpu_group, + ) expert_location_metadata.to(server_args.device) set_global_expert_location_metadata(expert_location_metadata) if self.tp_rank == 0 and get_bool_env_var( From b72221f2182177ce69c258023afd6cb4ce60e4d4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:20:44 +0800 Subject: [PATCH 1000/1089] more --- python/sglang/srt/poll_based_barrier.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index 40d7fbe82c8..2f30fffddc8 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -1,5 +1,7 @@ import torch +from sglang.srt.distributed import get_world_group + class PollBasedBarrier: def __init__(self, noop: bool = False): @@ -19,7 +21,7 @@ def poll_global_arrived(self) -> bool: def _compute_global_arrived(self) -> bool: local_arrived = self._noop or self._local_arrived - global_arrived = torch.tensor(local_arrived).cuda() + global_arrived = torch.tensor(local_arrived) # Can optimize if bottleneck - torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN) - return global_arrived.cpu().item() + torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN, group=get_world_group().cpu_group) + return global_arrived.item() From a77d21d1dd7cc02be088c3198559afefa6081b8c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:23:45 +0800 Subject: [PATCH 1001/1089] more --- test/srt/test_expert_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index d6caca23bf5..e523c9ad865 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,7 +3,6 @@ import requests import torch - from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -27,7 +26,7 @@ def test_expert_distribution_record(self): self._execute_core(**info) def _execute_core( - self, model_path: str, mode_detail: bool = False, tp_size: int = 1 + self, model_path: str, mode_detail: bool = False, tp_size: int = 1 ): """Test expert distribution record endpoints""" os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = ( @@ -41,6 +40,7 @@ def _execute_core( "--trust-remote-code", "--tp-size", str(tp_size), + "--enable-expert-distribution-recorder", ], ) From 352887c6138c09f3d047f8c6c69405207d5bb977 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:23:59 +0800 Subject: [PATCH 1002/1089] more --- test/srt/test_expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 1078b507aa3..72144f4d904 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -1,8 +1,8 @@ import unittest import torch -from python.sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map -from python.sglang.test.test_utils import CustomTestCase +from sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map +from sglang.test.test_utils import CustomTestCase class TestExpertLocation(CustomTestCase): From 27f4a148bdb2ccd66f00fa30ea01af8e8b6b9530 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:25:40 +0800 Subject: [PATCH 1003/1089] more --- test/srt/test_expert_location.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 72144f4d904..480dd31faae 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -12,7 +12,12 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): # Identity map ( [[[0], [1], [2], [3], [4], [5], [6], [7]]], - [[]], # TODO + [ + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + ], ), # Identity map + consider redundant experts ( From 5ae31c0bb545920d5b0f419068282a827a907c94 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:26:07 +0800 Subject: [PATCH 1004/1089] more --- test/srt/test_expert_distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index e523c9ad865..75b677b127e 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -26,7 +26,7 @@ def test_expert_distribution_record(self): self._execute_core(**info) def _execute_core( - self, model_path: str, mode_detail: bool = False, tp_size: int = 1 + self, model_path: str, mode_detail: bool = False, tp_size: int = 1 ): """Test expert distribution record endpoints""" os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DETAIL"] = ( @@ -41,6 +41,7 @@ def _execute_core( "--tp-size", str(tp_size), "--enable-expert-distribution-recorder", + "--disable-cuda-graph", ], ) From 3c21a84eaeeef1a143ec8a644736095827ee8e07 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:27:27 +0800 Subject: [PATCH 1005/1089] more --- test/srt/test_expert_location.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 480dd31faae..583f6a4b0ed 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -22,7 +22,10 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): # Identity map + consider redundant experts ( [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], - [[]], # TODO + [[[0, 1, 2, 3, 4, 5, 6, 7]], + [[8, 1, 2, 3, 4, 5, 6, 7]], + [[8, 9, 10, 11, 4, 5, 6, 7]], + [[0, 9, 10, 11, 4, 5, 6, 7]]], ), # One logical expert is put on ALL gpus ( @@ -47,9 +50,9 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map), num_gpus=4, num_physical_experts=12, - ) + ).tolist() print(f"{actual_output=} {expect_output=}") - self.assertEqual(actual_output.tolist(), expect_output) + self.assertEqual(actual_output, expect_output) if __name__ == 
"__main__": From f8559b7046f2f26e06e96a98a77820e529cd1f0a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:27:49 +0800 Subject: [PATCH 1006/1089] more --- python/sglang/srt/managers/expert_location.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 23bdc4b6ae4..21b9a016a80 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -270,8 +270,9 @@ def compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, num_gpus: int, num_physical_experts: int, + seed: int = 42, ): - r = random.Random() + r = random.Random(seed) num_local_physical_experts = num_physical_experts // num_gpus num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape From 12e5ca6ecda2a4a76d0cc835f274f1978c861a34 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:28:34 +0800 Subject: [PATCH 1007/1089] more --- test/srt/test_expert_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 75b677b127e..a1a4daf2cc4 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -42,6 +42,7 @@ def _execute_core( str(tp_size), "--enable-expert-distribution-recorder", "--disable-cuda-graph", + "--disable-overlap-schedule", ], ) From 05cc93af7190a9228c3dec3c9fd883b670eb30f1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:29:12 +0800 Subject: [PATCH 1008/1089] more --- test/srt/test_expert_location.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 583f6a4b0ed..a4069c7b93a 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -22,10 +22,8 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): # Identity map + consider redundant experts ( [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], - [[[0, 1, 2, 3, 4, 5, 6, 7]], - [[8, 1, 2, 3, 4, 5, 6, 7]], - [[8, 9, 10, 11, 4, 5, 6, 7]], - [[0, 9, 10, 11, 4, 5, 6, 7]]], + [[[0, 1, 2, 11, 4, 5, 6, 7]], [[8, 9, 2, 3, 4, 5, 6, 7]], + [[8, 1, 10, 3, 4, 5, 6, 7]], [[0, 9, 10, 11, 4, 5, 6, 7]]], ), # One logical expert is put on ALL gpus ( From 7929e8d8780ba7b1eed898a6b87a3db0e65d5cf6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:30:22 +0800 Subject: [PATCH 1009/1089] more --- test/srt/test_expert_location.py | 49 ++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index a4069c7b93a..71cc38da8b3 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -8,48 +8,55 @@ class TestExpertLocation(CustomTestCase): def test_compute_logical_to_rank_dispatch_physical_map(self): # 8 logical expert - for logical_to_all_physical_map, expect_output in [ + cases = [ # Identity map ( - [[[0], [1], [2], [3], [4], [5], [6], [7]]], - [ - [[0, 1, 2, 3, 4, 5, 6, 7]], - [[0, 1, 2, 3, 4, 5, 6, 7]], - [[0, 1, 2, 3, 4, 5, 6, 7]], - [[0, 1, 2, 3, 4, 5, 6, 7]], - ], + [[[0], [1], [2], [3], [4], [5], [6], [7]]], + [ + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + [[0, 1, 2, 3, 4, 5, 6, 7]], + ], ), # Identity map + consider redundant experts ( - [[[0, 8], [1, 9], [2, 10], [3, 
11], [4, -1], [5, -1], [6, -1], [7, -1]]], - [[[0, 1, 2, 11, 4, 5, 6, 7]], [[8, 9, 2, 3, 4, 5, 6, 7]], - [[8, 1, 10, 3, 4, 5, 6, 7]], [[0, 9, 10, 11, 4, 5, 6, 7]]], + [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], + [[[0, 1, 2, 11, 4, 5, 6, 7]], [[8, 9, 2, 3, 4, 5, 6, 7]], + [[8, 1, 10, 3, 4, 5, 6, 7]], [[0, 9, 10, 11, 4, 5, 6, 7]]], ), # One logical expert is put on ALL gpus ( - [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], - [8, -1, -1, -1], [10, -1, -1, -1]]], - [[]], # TODO + [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], + [8, -1, -1, -1], [10, -1, -1, -1]]], + [[]], # TODO ), # One logical expert is put multiple times on ONE gpu ( - [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], - [9, -1, -1]]], - [[]], # TODO + [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], + [9, -1, -1]]], + [[]], # TODO ), # Random ( - [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], - [7, -1, -1]]], - [[]], # TODO + [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], + [7, -1, -1]]], + [[]], # TODO ), - ]: + ] + + actual_outputs = [] + + for logical_to_all_physical_map, expect_output in cases: actual_output = compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map=torch.tensor(logical_to_all_physical_map), num_gpus=4, num_physical_experts=12, ).tolist() + actual_outputs.append(actual_output) print(f"{actual_output=} {expect_output=}") + + for (logical_to_all_physical_map, expect_output), actual_output in zip(cases, actual_outputs): self.assertEqual(actual_output, expect_output) From 5668d1cc6d7749d6ed31ac932c4ae2a8971da840 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:30:38 +0800 Subject: [PATCH 1010/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7da177999f3..a750f4b6c88 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -506,7 +506,7 @@ def dump(self): ) def flush_buffer_depending_on_expert_location_metadata(self): - self._logical_count += _convert_global_physical_count_to_logical_count(self._buffer_global_physical_count) + self._logical_count += _convert_global_physical_count_to_logical_count(self._buffer_global_physical_count, expert_location_metadata=self._expert_location_metadata) self._buffer_global_physical_count[...] 
= 0 From 1c424cc655c703387b303f569bc096fb049d42cb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:31:21 +0800 Subject: [PATCH 1011/1089] more --- test/srt/test_expert_location.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 71cc38da8b3..11cbe6372a3 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -29,19 +29,22 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): ( [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], [8, -1, -1, -1], [10, -1, -1, -1]]], - [[]], # TODO + [[[0, 1, 2, 4, 5, 7, 8, 10]], [[3, 1, 2, 4, 5, 7, 8, 10]], [[6, 1, 2, 4, 5, 7, 8, 10]], + [[9, 1, 2, 4, 5, 7, 8, 10]]], ), # One logical expert is put multiple times on ONE gpu ( [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], [9, -1, -1]]], - [[]], # TODO + [[[0, 3, 4, 5, 6, 7, 8, 9]], [[1, 3, 4, 5, 6, 7, 8, 9]], [[0, 3, 4, 5, 6, 7, 8, 9]], + [[2, 3, 4, 5, 6, 7, 8, 9]]], ), # Random ( [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], [7, -1, -1]]], - [[]], # TODO + [[[11, 0, 6, 8, 1, 10, 2, 7]], [[4, 5, 6, 8, 1, 10, 3, 7]], [[4, 5, 6, 8, 1, 10, 2, 7]], + [[11, 9, 6, 8, 1, 10, 3, 7]]], ), ] From aa1d88771b4a260d820e2b7d7a0376b8ccbcc55e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:31:48 +0800 Subject: [PATCH 1012/1089] more --- test/srt/test_expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 11cbe6372a3..12e677901f7 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -27,14 +27,14 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): ), # One logical expert is put on ALL gpus ( - [[[0, 3, 6, 9], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], + [[[3, 9, 6, 0], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], [8, -1, -1, -1], [10, -1, -1, -1]]], [[[0, 1, 2, 4, 5, 7, 8, 10]], [[3, 1, 2, 4, 5, 7, 8, 10]], [[6, 1, 2, 4, 5, 7, 8, 10]], [[9, 1, 2, 4, 5, 7, 8, 10]]], ), # One logical expert is put multiple times on ONE gpu ( - [[[0, 1, 2], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], + [[[2, 0, 1], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], [9, -1, -1]]], [[[0, 3, 4, 5, 6, 7, 8, 9]], [[1, 3, 4, 5, 6, 7, 8, 9]], [[0, 3, 4, 5, 6, 7, 8, 9]], [[2, 3, 4, 5, 6, 7, 8, 9]]], From c35fdcbc1a9f31df25daabbc205cf1c4ee10f4bf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:32:29 +0800 Subject: [PATCH 1013/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a750f4b6c88..959e22d58d0 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -466,7 +466,7 @@ def postprocess_dumps( dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata", ): - logical_count = torch.stack([item["logical_count"] for item in dumps]).sum(dim=0) + logical_count = torch.tensor([item["logical_count"] for item in dumps]).sum(dim=0) return dict(logical_count=logical_count.tolist()) def 
__init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): From 55344c5a80973f4961b0b00a863a9fec935dc350 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:33:23 +0800 Subject: [PATCH 1014/1089] more --- test/srt/test_expert_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 12e677901f7..6ed81dfae83 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -36,8 +36,8 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): ( [[[2, 0, 1], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], [9, -1, -1]]], - [[[0, 3, 4, 5, 6, 7, 8, 9]], [[1, 3, 4, 5, 6, 7, 8, 9]], [[0, 3, 4, 5, 6, 7, 8, 9]], - [[2, 3, 4, 5, 6, 7, 8, 9]]], + [[[2, 3, 4, 5, 6, 7, 8, 9]], [[0, 3, 4, 5, 6, 7, 8, 9]], [[2, 3, 4, 5, 6, 7, 8, 9]], + [[1, 3, 4, 5, 6, 7, 8, 9]]], ), # Random ( From 3c67506e39ab477c04a75f4e07c36793c54721f8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:36:21 +0800 Subject: [PATCH 1015/1089] more --- test/srt/test_eplb.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 20e6f50a3af..6d4926a69a6 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -9,7 +9,6 @@ import numpy as np import sglang as sgl import torch -from python.sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -250,8 +249,8 @@ def test_nontrivial_location(self): offset = 3 physical_to_logical_map = ( - offset - + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( + offset + + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( _NUM_HIDDEN_LAYERS, 1 ) ) % _NUM_ROUTED_EXPERTS From a38b0ae22f0534b91f3aa9c3c9b8d5c39c7c0a86 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:39:59 +0800 Subject: [PATCH 1016/1089] more --- python/sglang/srt/layers/moe/ep_moe/kernels.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index d59eeeaa45d..d1a2dee0ade 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -4,7 +4,6 @@ import torch import triton import triton.language as tl - from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import ( @@ -444,13 +443,13 @@ def gelu_and_mul_triton_kernel( * ( 1 + tanh( - kAlpha - * ( - gate_output - + 0.044715 * gate_output * gate_output * gate_output - ) + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output ) ) + ) ) gate_output = gate_output.to(InDtype) @@ -674,6 +673,9 @@ def grouped_gemm_triton( assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1] + else: + # TODO temp, will refactor + a = DisposibleTensor.maybe_unwrap(a) # TODO: adjust config or tune kernel # Reduce block size to prevent L40 shared memory overflow. 
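
The test patches above pin down expected outputs for compute_logical_to_rank_dispatch_physical_map. As a reading aid, here is a minimal standalone sketch of the contract those cases encode; it is not the repository implementation. For each (layer, logical expert) pair, every GPU rank is mapped to one physical replica, preferring a replica hosted on that rank and otherwise falling back to a replica on another rank. Upstream breaks the fallback ties with a seeded fair-random choice (_fair_choices with seed=42, per one of the patches above), while this sketch deterministically takes the first candidate, so it reproduces only the identity-map test case exactly. The name sketch_dispatch_map is illustrative and does not exist in the codebase.

    import torch

    def sketch_dispatch_map(
        logical_to_all_physical_map: torch.Tensor,  # (layers, logical experts, replica slots); -1 marks an empty slot
        num_gpus: int,
        num_physical_experts: int,
    ) -> torch.Tensor:
        num_local = num_physical_experts // num_gpus  # physical experts hosted per GPU rank
        num_layers, num_logical, _ = logical_to_all_physical_map.shape
        out = torch.full(
            (num_gpus, num_layers, num_logical),
            -1,
            dtype=logical_to_all_physical_map.dtype,
        )
        for layer in range(num_layers):
            for logical in range(num_logical):
                replicas = [
                    p
                    for p in logical_to_all_physical_map[layer, logical].tolist()
                    if p != -1
                ]
                for gpu in range(num_gpus):
                    # Prefer a replica that already lives on this rank; a physical
                    # expert p is hosted on rank p // num_local.
                    same_gpu = [p for p in replicas if p // num_local == gpu]
                    # Deterministic fallback here; upstream uses a seeded fair-random pick.
                    out[gpu, layer, logical] = same_gpu[0] if same_gpu else replicas[0]
        return out

    # Identity map: one replica per logical expert, so every rank dispatches to it,
    # matching the first case in test_expert_location.py.
    identity = torch.tensor([[[0], [1], [2], [3], [4], [5], [6], [7]]])
    print(sketch_dispatch_map(identity, num_gpus=4, num_physical_experts=12))

The rank-local preference visible in the diffs (same_gpu_physical_expert_ids) suggests the design intent: after an EPLB rebalance places redundant replicas, a rank dispatches to its own copy whenever one exists, and only logical experts with no local replica fall back to another rank.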
From 2f09e4d00ddc6c3bbe1f5f0fefc165d027fe0e72 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:49:31 +0800 Subject: [PATCH 1017/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 959e22d58d0..89915f6169f 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -474,7 +474,7 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int self._buffer_global_physical_count = torch.zeros( ( self._expert_location_metadata.num_layers, - self._expert_location_metadata.num_local_physical_experts, + self._expert_location_metadata.num_physical_experts, ) ) self._logical_count = torch.zeros( From f473f290529edbe4058c5e77259b6d1e8a8cba5c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:56:41 +0800 Subject: [PATCH 1018/1089] more --- python/sglang/srt/managers/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7ab23c878ff..11e385a6266 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1809,7 +1809,6 @@ def _pause_engine(self) -> Tuple[List[Req], int]: def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): self.tp_worker.worker.model_runner.update_expert_location(recv_req) - return UpdateExpertLocationReqOutput() def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): """In-place update of the weights from disk.""" From 1bdd7dd7e1d3be9e4188704ce4cf37612e796c67 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 15:57:03 +0800 Subject: [PATCH 1019/1089] more --- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 11e385a6266..2537f5a2c60 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1808,7 +1808,7 @@ def _pause_engine(self) -> Tuple[List[Req], int]: raise NotImplementedError() def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): - self.tp_worker.worker.model_runner.update_expert_location(recv_req) + self.tp_worker.worker.model_runner.update_expert_location_start(recv_req) def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): """In-place update of the weights from disk.""" diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1d36d7717d4..7ee67c59c73 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -499,7 +499,7 @@ def load_model(self): f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." 
) from None - def update_expert_location(self, recv_req: UpdateExpertLocationReqInput): + def update_expert_location_start(self, recv_req: UpdateExpertLocationReqInput): self._expert_location_updater.start(recv_req) def event_loop_step(self) -> List[Any]: From 547804b6bf9120f5fb5bc4b2775994d5f8dcd8fe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:02:01 +0800 Subject: [PATCH 1020/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 54e9aba260c..a4e02d46b31 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -714,7 +714,7 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): layer_id_lens = list(range(0, num_layers, 10)) + [num_layers] for layer_id_end in layer_id_lens: - logger.info(f"update_expert_location handling 0~{layer_id_end} layers") + logger.info(f"update_expert_location handling up to {layer_id_end}th layer") partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, layer_id_len=layer_id_end) From cfb7672d3486313e5fe02ed634ecfd218e3e8c1b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:04:27 +0800 Subject: [PATCH 1021/1089] more --- test/srt/test_eplb.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 6d4926a69a6..e892993f7ca 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -73,6 +73,7 @@ async def _task_generate(): tasks = [] async for content in _yield_with_poisson_process(contents_duplicated, action_rate=request_rate): + print("Action: start async_generate") tasks.append(asyncio.create_task(engine.async_generate( prompt=content["prompt"], sampling_params=dict(temperature=0, max_new_tokens=4), @@ -90,7 +91,10 @@ async def _task_rebalance(): return for i in range(num_rebalance): + print("Action: start eplb_rebalance") await engine.tokenizer_manager.eplb_rebalance() + print("Action: end eplb_rebalance") + await asyncio.sleep(1.0) with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( From 0ff54bd4ad7c0f4224936089608eb2859c526120 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:06:30 +0800 Subject: [PATCH 1022/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index e892993f7ca..434bb039ac7 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -119,7 +119,7 @@ async def _task_rebalance(): engine = sgl.Engine(**engine_kwargs) loop = asyncio.get_event_loop() - loop.run_until_complete(_main_async) + loop.run_until_complete(_main_async()) print(f"Action: shutdown engine") engine.shutdown() From 950054b0d21db7e70c46dadcff91831797d18b07 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:09:59 +0800 Subject: [PATCH 1023/1089] more --- test/srt/test_eplb.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 434bb039ac7..aee513b8cc3 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -41,23 +41,23 @@ def _test_eplb_many_rebalances_core(self, enable_eplb: bool = True): contents_raw = [ dict( prompt="1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", - expect_output="TODO", + 
expect_output='6, 1', ), dict( prompt="2*1=2, 2*2=4, 2*3=6, 2*4=", - expect_output="TODO", + expect_output='8, 2', ), dict( prompt="10*1=10, 10*2=20, 10*3=30, 10*4=40, 10*5=50, 10*6=", - expect_output="TODO", + expect_output='60, ', ), dict( prompt="2/2=1, 4/2=2, 6/2=3, 8/2=", - expect_output="TODO", + expect_output='4, 1', ), dict( prompt="One plus one is two, one plus two is three, one plus three is", - expect_output="TODO", + expect_output=' four, one plus', ), ] From 2a6fe7c431b10b3ae848eadd5588b1b673ab83a9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:11:45 +0800 Subject: [PATCH 1024/1089] more --- test/srt/test_eplb.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index aee513b8cc3..f60d6f676a4 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -2,6 +2,7 @@ import json import random import tempfile +import time import unittest from pathlib import Path from typing import List @@ -73,7 +74,7 @@ async def _task_generate(): tasks = [] async for content in _yield_with_poisson_process(contents_duplicated, action_rate=request_rate): - print("Action: start async_generate") + print(f"[{time.time()}] Action: start async_generate") tasks.append(asyncio.create_task(engine.async_generate( prompt=content["prompt"], sampling_params=dict(temperature=0, max_new_tokens=4), @@ -91,9 +92,9 @@ async def _task_rebalance(): return for i in range(num_rebalance): - print("Action: start eplb_rebalance") + print(f"[{time.time()}] Action: start eplb_rebalance") await engine.tokenizer_manager.eplb_rebalance() - print("Action: end eplb_rebalance") + print(f"[{time.time()}] Action: end eplb_rebalance") await asyncio.sleep(1.0) with tempfile.TemporaryDirectory() as tmpdir: From a0842dcfa4282e9701f7ef7005ea7ff0071662a5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:12:27 +0800 Subject: [PATCH 1025/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f60d6f676a4..a3f3e355246 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -38,7 +38,7 @@ def _test_eplb_many_rebalances_core(self, enable_eplb: bool = True): num_rebalance = 20 request_rate = 20 - content_duplicate_num = 20 + content_duplicate_num = 200 contents_raw = [ dict( prompt="1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", From 3d29e508ad635353e74df0e0e42d146db8a39589 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:12:55 +0800 Subject: [PATCH 1026/1089] more --- test/srt/test_eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index a3f3e355246..75a30cc0a90 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -36,7 +36,7 @@ def test_eplb_many_rebalances_baseline(self): def _test_eplb_many_rebalances_core(self, enable_eplb: bool = True): print("Action: test_eplb_many_rebalances") - num_rebalance = 20 + num_rebalance = 10 request_rate = 20 content_duplicate_num = 200 contents_raw = [ From 917f22e1b35fdae8170f0881ed8151aaaa144ffc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:16:08 +0800 Subject: [PATCH 1027/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a4e02d46b31..2737bb7fc48 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ 
b/python/sglang/srt/managers/tokenizer_manager.py @@ -711,7 +711,7 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) num_layers = old_expert_location_metadata.num_layers - layer_id_lens = list(range(0, num_layers, 10)) + [num_layers] + layer_id_lens = list(range(10, num_layers, 10)) + [num_layers] for layer_id_end in layer_id_lens: logger.info(f"update_expert_location handling up to {layer_id_end}th layer") From 198f327891ee10ec9cbcd2954987369d5ebc975b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:16:28 +0800 Subject: [PATCH 1028/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2737bb7fc48..e0e957f3544 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -711,6 +711,7 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) num_layers = old_expert_location_metadata.num_layers + # pretty arbitrary choice; can optimize if bottleneck layer_id_lens = list(range(10, num_layers, 10)) + [num_layers] for layer_id_end in layer_id_lens: From 2fc6df17c635657e4cbded0c4c5deb8452f115bd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:17:25 +0800 Subject: [PATCH 1029/1089] more --- test/srt/test_eplb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 75a30cc0a90..7ea00c3aa93 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -92,9 +92,9 @@ async def _task_rebalance(): return for i in range(num_rebalance): - print(f"[{time.time()}] Action: start eplb_rebalance") + print(f"[{time.time()}] Action: start eplb_rebalance {i}") await engine.tokenizer_manager.eplb_rebalance() - print(f"[{time.time()}] Action: end eplb_rebalance") + print(f"[{time.time()}] Action: end eplb_rebalance {i}") await asyncio.sleep(1.0) with tempfile.TemporaryDirectory() as tmpdir: From 31bef4e7d75150d2675bb9bb0c79beb432d352f3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:20:36 +0800 Subject: [PATCH 1030/1089] more --- test/srt/test_eplb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 7ea00c3aa93..0a76bb37571 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -114,6 +114,7 @@ async def _task_rebalance(): dp_size=2, log_level="info", disable_radix_cache=True, + mem_fraction_static=0.8, ) print(f"Action: start engine") From 09db1a1d0fbda42a699c88ff184362664312f1df Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:54:22 +0800 Subject: [PATCH 1031/1089] more --- python/sglang/srt/server_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e0f91ab3b37..042b03717bf 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -338,7 +338,9 @@ def __post_init__(self): if self.enable_eplb: self.enable_expert_distribution_recorder = True - logger.info(f"EPLB is enabled. The enable_expert_distribution_recorder is automatically enabled.") + self.disable_overlap_schedule = True + logger.info( + f"EPLB is enabled. 
The enable_expert_distribution_recorder and disable_overlap_schedule is automatically set.") if self.enable_eplb or (self.init_expert_location is not None): self.ep_dispatch_algorithm = "static" logger.info( From 1a6e5cac01393090c385a407f823c26b34355017 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:54:51 +0800 Subject: [PATCH 1032/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 23574f8b8f4..45d21709891 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -29,6 +29,9 @@ def __init__(self, model_runner: "ModelRunner"): ) self._prepare_end_barrier = PollBasedBarrier(noop=False) self._ongoing_req: Optional[UpdateExpertLocationReqInput] = None + assert ( + model_runner.server_args.disable_overlap_schedule + ), "ExpertLocationUpdater requires overlap scheduler to be disabled" def start(self, req: UpdateExpertLocationReqInput): assert self._ongoing_req is None From a96d4830ab234631837fcede7d9609c406074159 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:55:55 +0800 Subject: [PATCH 1033/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/server_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7ee67c59c73..25d12f79289 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -211,7 +211,7 @@ def __init__( # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) - self._expert_location_updater = ExpertLocationUpdater(self) + self._expert_location_updater = ExpertLocationUpdater(self) if server_args.expert_location_updater_mode is not None else None def initialize(self, min_per_gpu_memory: float): server_args = self.server_args diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 042b03717bf..cfb1d67696c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -165,7 +165,7 @@ class ServerArgs: ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "random"]] = None init_expert_location: Optional[str] = None - expert_location_updater_mode: Literal["pin_memory", "pageable_memory"] = "pin_memory" + expert_location_updater_mode: Optional[Literal["pin_memory", "pageable_memory"]] = None enable_eplb: bool = False eplb_storage_dir: str = "/tmp/eplb_storage" eplb_rebalance_period: Optional[int] = None From 3cda0fbda130492bf4668ab7e763f01b88b49593 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:56:00 +0800 Subject: [PATCH 1034/1089] more --- python/sglang/srt/model_executor/model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 25d12f79289..0fa690126de 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -211,7 +211,8 @@ def __init__( # If it is a draft model tp_group can be different. 
self.initialize(min_per_gpu_memory) - self._expert_location_updater = ExpertLocationUpdater(self) if server_args.expert_location_updater_mode is not None else None + self._expert_location_updater = ExpertLocationUpdater( + self) if server_args.expert_location_updater_mode is not None else None def initialize(self, min_per_gpu_memory: float): server_args = self.server_args From 68b7aa0dcba113e25d2aea16beed28981d2b5819 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:56:22 +0800 Subject: [PATCH 1035/1089] more --- python/sglang/srt/model_executor/expert_location_updater.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 45d21709891..70b339b8fb5 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -21,7 +21,10 @@ class ExpertLocationUpdater: def __init__(self, model_runner: "ModelRunner"): self._model_runner = model_runner self._model_weight_updater = ModelWeightUpdater( - init_pin_memory=model_runner.server_args.expert_location_updater_mode == "pin_memory", + init_pin_memory={ + "pin_memory": True, + "pageable_memory": False, + }[model_runner.server_args.expert_location_updater_mode], load_format=model_runner.server_args.load_format, model_config=model_runner.model_config, model=model_runner.model, From 80837da5d763307b34456af17ddc960920b29bd1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:57:17 +0800 Subject: [PATCH 1036/1089] more --- python/sglang/srt/server_args.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index cfb1d67696c..63e64c57ed3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -342,9 +342,12 @@ def __post_init__(self): logger.info( f"EPLB is enabled. The enable_expert_distribution_recorder and disable_overlap_schedule is automatically set.") if self.enable_eplb or (self.init_expert_location is not None): - self.ep_dispatch_algorithm = "static" + if self.ep_dispatch_algorithm is None: + self.ep_dispatch_algorithm = "static" + if self.expert_location_updater_mode is None: + self.expert_location_updater_mode = "pin_memory" logger.info( - f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is set to `static`.") + f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm and expert_location_updater_mode are configured.") if self.ep_num_redundant_experts > 0: assert ( From ee8ffc1a2e6382bcac10e778f9cb7a3f6b0ac1f9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 16:57:43 +0800 Subject: [PATCH 1037/1089] more --- python/sglang/srt/server_args.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 63e64c57ed3..c37ec2432ee 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -339,15 +339,14 @@ def __post_init__(self): if self.enable_eplb: self.enable_expert_distribution_recorder = True self.disable_overlap_schedule = True + if self.expert_location_updater_mode is None: + self.expert_location_updater_mode = "pin_memory" logger.info( - f"EPLB is enabled. The enable_expert_distribution_recorder and disable_overlap_schedule is automatically set.") + f"EPLB is enabled. 
The enable_expert_distribution_recorder, disable_overlap_schedule and expert_location_updater_mode are automatically set.") if self.enable_eplb or (self.init_expert_location is not None): if self.ep_dispatch_algorithm is None: self.ep_dispatch_algorithm = "static" - if self.expert_location_updater_mode is None: - self.expert_location_updater_mode = "pin_memory" - logger.info( - f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm and expert_location_updater_mode are configured.") + logger.info(f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured.") if self.ep_num_redundant_experts > 0: assert ( From d6211fcb5865a0d8b8e0d828a6e164fcbf033d6b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 17:01:07 +0800 Subject: [PATCH 1038/1089] more --- python/sglang/srt/server_args.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c37ec2432ee..8c984783f01 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -338,11 +338,14 @@ def __post_init__(self): if self.enable_eplb: self.enable_expert_distribution_recorder = True - self.disable_overlap_schedule = True if self.expert_location_updater_mode is None: self.expert_location_updater_mode = "pin_memory" logger.info( - f"EPLB is enabled. The enable_expert_distribution_recorder, disable_overlap_schedule and expert_location_updater_mode are automatically set.") + f"EPLB is enabled. The enable_expert_distribution_recorder and expert_location_updater_mode are automatically set.") + if self.expert_location_updater_mode is None: + self.disable_overlap_schedule = True + logger.info( + f"ExpertLocationUpdater is enabled. The disable_overlap_schedule is set.") if self.enable_eplb or (self.init_expert_location is not None): if self.ep_dispatch_algorithm is None: self.ep_dispatch_algorithm = "static" From 5e3ee7e93f5d8199b73cadf4b25ba075c84ca831 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 17:01:15 +0800 Subject: [PATCH 1039/1089] more --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8c984783f01..68a8e005a4e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -342,7 +342,7 @@ def __post_init__(self): self.expert_location_updater_mode = "pin_memory" logger.info( f"EPLB is enabled. The enable_expert_distribution_recorder and expert_location_updater_mode are automatically set.") - if self.expert_location_updater_mode is None: + if self.expert_location_updater_mode is not None: self.disable_overlap_schedule = True logger.info( f"ExpertLocationUpdater is enabled. 
The disable_overlap_schedule is set.") From c1410c46de7f8b366c2b002aebce01c6e094f270 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 17:46:03 +0800 Subject: [PATCH 1040/1089] more --- python/sglang/srt/layers/moe/topk.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 3464bb83104..b954207319b 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -159,6 +159,7 @@ def biased_grouped_topk_impl( num_expert_group: int = 0, topk_group: int = 0, n_share_experts_fusion: int = 0, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -207,7 +208,11 @@ def biased_grouped_topk_impl( ) topk_weights = topk_weights / topk_weights_sum - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + + return topk_weights, topk_ids def biased_grouped_topk( @@ -220,6 +225,7 @@ def biased_grouped_topk( topk_group: int = 0, compiled: bool = True, n_share_experts_fusion: int = 0, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): biased_grouped_topk_fn = ( torch.compile( @@ -237,6 +243,7 @@ def biased_grouped_topk( num_expert_group, topk_group, n_share_experts_fusion=n_share_experts_fusion, + expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -272,7 +279,6 @@ def select_experts( expert_location_dispatch_info=expert_location_dispatch_info, ) else: - assert expert_location_dispatch_info is None topk_weights, topk_ids = biased_grouped_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -282,6 +288,7 @@ def select_experts( num_expert_group=num_expert_group, topk_group=topk_group, n_share_experts_fusion=n_share_experts_fusion, + expert_location_dispatch_info=expert_location_dispatch_info, ) elif torch_native and custom_routing_function is None: assert expert_location_dispatch_info is None From cc0d5d367ccc4d0eacf99753d5c2884d585af6a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 17:53:50 +0800 Subject: [PATCH 1041/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e0e957f3544..39041f721c7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -705,8 +705,8 @@ async def eplb_save_expert_distribution(self): async def update_expert_location(self, obj: UpdateExpertLocationReqInput): self.auto_create_handle_loop() assert ( - self.server_args.enable_scheduler_input_blocker and (self.server_args.ep_dispatch_algorithm is not None) - ), f"update_expert_location requires enable_scheduler_input_blocker and non-null ep_dispatch_algorithm" + self.server_args.ep_dispatch_algorithm is not None + ), f"update_expert_location requires ep_dispatch_algorithm" old_expert_location_metadata = copy.deepcopy(self.expert_location_metadata) num_layers = old_expert_location_metadata.num_layers From 80dbbc37980018874775e08f8b0190acc0b1f4fe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 17:57:17 +0800 Subject: [PATCH 1042/1089] more --- 
python/sglang/srt/model_executor/expert_location_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 70b339b8fb5..6f431114a37 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -87,7 +87,7 @@ def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[i info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name) return ( isinstance(info, ModelParamNameInfoMoe) - and (info.expert_id in interesting_logical_experts_of_layer[info.layer_id]) + and (info.expert_id in interesting_logical_experts_of_layer.get(info.layer_id, [])) ) From e26de0cc75adf5f19c369768efccfb8d75c4de17 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 18:15:11 +0800 Subject: [PATCH 1043/1089] more --- .../sglang/srt/model_executor/model_weight_updater.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 56326a0c3ea..9c9334f1399 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,3 +1,5 @@ +import datetime +import logging from abc import ABC from dataclasses import dataclass from typing import Tuple, List, Callable, Iterable @@ -9,6 +11,8 @@ from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype +logger = logging.getLogger(__name__) + class ModelWeightUpdater: def __init__( @@ -31,6 +35,7 @@ def __init__( self._state: _State = _StateIdle() def start_prepare(self, weight_filter: Callable[[str], bool]): + _log_with_accurate_time("ModelWeightUpdater.start_prepare start") assert isinstance(self._state, _StateIdle) all_weights_iterator = self._model_weight_source.get_all_weights() @@ -38,6 +43,7 @@ def start_prepare(self, weight_filter: Callable[[str], bool]): self._memory_transfer_manager.enqueue(interesting_weights) self._state = _StateAwaitMemoryTransfer() + _log_with_accurate_time("ModelWeightUpdater.start_prepare end") def poll_prepare_end(self): memory_transfer_outputs = self._memory_transfer_manager.get_outputs() @@ -65,6 +71,10 @@ def act(self): self._state = _StateIdle() +def _log_with_accurate_time(message): + logger.info(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}") + + class _State(ABC): pass From 9048c0724fc25811ba28540bb1244a6b7033a2c7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 18:15:32 +0800 Subject: [PATCH 1044/1089] more --- python/sglang/srt/model_executor/model_weight_updater.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 9c9334f1399..3d6212b98d7 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -55,10 +55,12 @@ def poll_prepare_end(self): return True def _handle_memory_transfer_output(self, memory_transfer_output): + _log_with_accurate_time("ModelWeightUpdater.handle_memory_transfer_output") assert isinstance(self._state, _StateAwaitMemoryTransfer) self._state = _StatePrepared(named_tensors=memory_transfer_output) def act(self): + 
From e26de0cc75adf5f19c369768efccfb8d75c4de17 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 18:15:11 +0800
Subject: [PATCH 1043/1089] more

---
 .../sglang/srt/model_executor/model_weight_updater.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 56326a0c3ea..9c9334f1399 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -1,3 +1,5 @@
+import datetime
+import logging
 from abc import ABC
 from dataclasses import dataclass
 from typing import Tuple, List, Callable, Iterable
@@ -9,6 +11,8 @@
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 
+logger = logging.getLogger(__name__)
+
 
 class ModelWeightUpdater:
     def __init__(
@@ -31,6 +35,7 @@ def __init__(
         self._state: _State = _StateIdle()
 
     def start_prepare(self, weight_filter: Callable[[str], bool]):
+        _log_with_accurate_time("ModelWeightUpdater.start_prepare start")
         assert isinstance(self._state, _StateIdle)
 
         all_weights_iterator = self._model_weight_source.get_all_weights()
@@ -38,6 +43,7 @@ def start_prepare(self, weight_filter: Callable[[str], bool]):
         self._memory_transfer_manager.enqueue(interesting_weights)
 
         self._state = _StateAwaitMemoryTransfer()
+        _log_with_accurate_time("ModelWeightUpdater.start_prepare end")
 
     def poll_prepare_end(self):
         memory_transfer_outputs = self._memory_transfer_manager.get_outputs()
@@ -65,6 +71,10 @@ def act(self):
         self._state = _StateIdle()
 
 
+def _log_with_accurate_time(message):
+    logger.info(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}")
+
+
 class _State(ABC):
     pass

From 9048c0724fc25811ba28540bb1244a6b7033a2c7 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 18:15:32 +0800
Subject: [PATCH 1044/1089] more

---
 python/sglang/srt/model_executor/model_weight_updater.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 9c9334f1399..3d6212b98d7 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -55,10 +55,12 @@ def poll_prepare_end(self):
         return True
 
     def _handle_memory_transfer_output(self, memory_transfer_output):
+        _log_with_accurate_time("ModelWeightUpdater.handle_memory_transfer_output")
         assert isinstance(self._state, _StateAwaitMemoryTransfer)
         self._state = _StatePrepared(named_tensors=memory_transfer_output)
 
     def act(self):
+        _log_with_accurate_time("ModelWeightUpdater.act start")
         assert isinstance(self._state, _StatePrepared)
 
         target_device = torch.device(self._device)
@@ -69,6 +71,7 @@ def act(self):
         DefaultModelLoader.load_weights_and_postprocess(self._model, named_tensors, target_device)
 
         self._state = _StateIdle()
+        _log_with_accurate_time("ModelWeightUpdater.act end")
 
 
 def _log_with_accurate_time(message):

From 03e0a0c5d1912effbcb924c780b89b653f3990d5 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 18:17:09 +0800
Subject: [PATCH 1045/1089] more

---
 .../srt/model_executor/expert_location_updater.py     | 11 +++++++++++
 .../sglang/srt/model_executor/model_weight_updater.py | 10 ----------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py
index 6f431114a37..226e1934ec5 100644
--- a/python/sglang/srt/model_executor/expert_location_updater.py
+++ b/python/sglang/srt/model_executor/expert_location_updater.py
@@ -37,6 +37,7 @@ def __init__(self, model_runner: "ModelRunner"):
         ), "ExpertLocationUpdater requires overlap scheduler to be disabled"
 
     def start(self, req: UpdateExpertLocationReqInput):
+        _log_with_accurate_time("ExpertLocationUpdater.start begin")
         assert self._ongoing_req is None
         self._ongoing_req = req
 
@@ -49,11 +50,13 @@ def start(self, req: UpdateExpertLocationReqInput):
         self._model_weight_updater.start_prepare(
             weight_filter=lambda name: self._weight_filter(name, interesting_logical_experts_of_layer),
         )
+        _log_with_accurate_time("ExpertLocationUpdater.start end")
 
     def event_loop_step(self) -> List[UpdateExpertLocationReqOutput]:
         outputs = []
 
         if self._model_weight_updater.poll_prepare_end():
+            _log_with_accurate_time("ExpertLocationUpdater.event_loop_step observe local_arrive")
             self._prepare_end_barrier.local_arrive()
 
         if self._prepare_end_barrier.poll_global_arrived():
@@ -62,6 +65,7 @@ def event_loop_step(self) -> List[UpdateExpertLocationReqOutput]:
         return outputs
 
     def _act(self):
+        _log_with_accurate_time("ExpertLocationUpdater.act start")
         torch.distributed.barrier()
 
         get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata()
@@ -74,13 +78,16 @@ def _act(self):
             f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
         )
 
+        _log_with_accurate_time("ExpertLocationUpdater.act execute ModelWeightUpdater.act start")
         self._model_weight_updater.act()
+        _log_with_accurate_time("ExpertLocationUpdater.act execute ModelWeightUpdater.act end")
 
         torch.distributed.barrier()
 
         assert self._ongoing_req is not None
         self._ongoing_req = None
 
+        _log_with_accurate_time("ExpertLocationUpdater.act end")
         return UpdateExpertLocationReqOutput()
 
     def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]):
@@ -109,3 +116,7 @@ def _get_partial_physical_to_logical_map(meta: ExpertLocationMetadata, layer_id:
         new_partial_map = _get_partial_physical_to_logical_map(new_expert_location_metadata, layer_id)
         interesting_logical_experts_of_layer[layer_id] = new_partial_map[new_partial_map != old_partial_map].tolist()
     return interesting_logical_experts_of_layer
+
+
+def _log_with_accurate_time(message):
+    logger.info(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}")
diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 3d6212b98d7..5016319755e 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -1,4 +1,3 @@
-import datetime
 import logging
 from abc import ABC
 from dataclasses import dataclass
@@ -35,7 +34,6 @@ def __init__(
         self._state: _State = _StateIdle()
 
     def start_prepare(self, weight_filter: Callable[[str], bool]):
-        _log_with_accurate_time("ModelWeightUpdater.start_prepare start")
         assert isinstance(self._state, _StateIdle)
 
         all_weights_iterator = self._model_weight_source.get_all_weights()
@@ -43,7 +41,6 @@ def start_prepare(self, weight_filter: Callable[[str], bool]):
         self._memory_transfer_manager.enqueue(interesting_weights)
 
         self._state = _StateAwaitMemoryTransfer()
-        _log_with_accurate_time("ModelWeightUpdater.start_prepare end")
 
     def poll_prepare_end(self):
         memory_transfer_outputs = self._memory_transfer_manager.get_outputs()
@@ -55,12 +52,10 @@ def poll_prepare_end(self):
         return True
 
     def _handle_memory_transfer_output(self, memory_transfer_output):
-        _log_with_accurate_time("ModelWeightUpdater.handle_memory_transfer_output")
         assert isinstance(self._state, _StateAwaitMemoryTransfer)
         self._state = _StatePrepared(named_tensors=memory_transfer_output)
 
     def act(self):
-        _log_with_accurate_time("ModelWeightUpdater.act start")
         assert isinstance(self._state, _StatePrepared)
 
         target_device = torch.device(self._device)
@@ -71,11 +66,6 @@ def act(self):
         DefaultModelLoader.load_weights_and_postprocess(self._model, named_tensors, target_device)
 
         self._state = _StateIdle()
-        _log_with_accurate_time("ModelWeightUpdater.act end")
-
-
-def _log_with_accurate_time(message):
-    logger.info(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}")
 
 
 class _State(ABC):

From 6ca241a393c0d95b20591baf069ee88b679af21f Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 18:21:34 +0800
Subject: [PATCH 1046/1089] more

---
 python/sglang/srt/model_executor/expert_location_updater.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py
index 226e1934ec5..12123163385 100644
--- a/python/sglang/srt/model_executor/expert_location_updater.py
+++ b/python/sglang/srt/model_executor/expert_location_updater.py
@@ -1,3 +1,4 @@
+import datetime
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional
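Patches 1043-1046 instrument the prepare/act path with _log_with_accurate_time (patch 1045 then moves the helper next to its caller). The hand-rolled timestamp exists because logging's default %(asctime)s only resolves to milliseconds, while strftime's %f field yields microseconds, which helps when bracketing sub-millisecond GPU-side steps. A self-contained version of the idiom (a sketch, not the repository module):

    import datetime
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("timing-demo")

    def log_with_accurate_time(message: str) -> None:
        # %f appends microseconds, e.g. 2025-04-17 18:15:11.123456
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
        logger.info(f"[{now}] {message}")

    log_with_accurate_time("prepare start")
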
From b98d24840dda3a94842ddcb2f33b1c8785188e74 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:24:10 +0800
Subject: [PATCH 1047/1089] more

---
 python/sglang/srt/models/deepseek_v2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index fc7ea7cbe99..232f5a79e6b 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1461,11 +1461,11 @@ def forward(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self):
+    def post_load_weights(self, enable_mla_postprocess: bool = True):
 
         # Perform post-processing after loading weights
 
-        if not global_server_args_dict["disable_mla"]:
+        if enable_mla_postprocess and not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
                 if hasattr(self_attn.kv_b_proj, "qweight"):
@@ -1670,7 +1670,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        self.post_load_weights()
+        self.post_load_weights(enable_mla_postprocess=TODO)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight

From 2228a29d8eba604fa1505f467116740b56183b63 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:24:52 +0800
Subject: [PATCH 1048/1089] more

---
 python/sglang/srt/models/deepseek_v2.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 232f5a79e6b..fce64a9c4f5 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1610,7 +1610,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         )
 
         params_dict = dict(self.named_parameters())
+        exist_mla_weights = False
         for name, loaded_weight in weights:
+            exist_mla_weights |= "self_attn" in name
+
             # TODO(HandH1998): Modify it when nextn is supported.
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1670,7 +1673,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        self.post_load_weights(enable_mla_postprocess=TODO)
+        self.post_load_weights(enable_mla_postprocess=exist_mla_weights)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight

From d44b193e59138f06fe11d89b6a91de32e39d932b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:25:11 +0800
Subject: [PATCH 1049/1089] more

---
 python/sglang/srt/model_executor/model_weight_updater.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 5016319755e..016764aa7b0 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -98,12 +98,15 @@ def __init__(self, load_format: str, model_config: ModelConfig, model):
         self._model_config = model_config
         self._model = model
 
-    def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]:
         load_config = LoadConfig(load_format=self._load_format)
         loader = get_model_loader(load_config)
         assert isinstance(loader, DefaultModelLoader)
         with set_default_torch_dtype(self._model_config.dtype):
-            yield from loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model))
+            weights = list(
+                loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model)))
+
+    def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]:
+        return TODO

From 07acebf2aba9fd3759a4eee618984773a83bb6aa Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:25:43 +0800
Subject: [PATCH 1050/1089] more

---
 .../model_executor/model_weight_updater.py | 26 +++++--------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 016764aa7b0..e1543dbc5b8 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -87,26 +87,12 @@ class _StatePrepared(_State):
     named_tensors: List[Tuple[str, torch.Tensor]]
 
 
-class _ModelWeightSourceBase(ABC):
-    def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]:
-        raise NotImplementedError
-
-
-class _ModelWeightSourceVanilla(_ModelWeightSourceBase):
-    def __init__(self, load_format: str, model_config: ModelConfig, model):
-        self._load_format = load_format
-        self._model_config = model_config
-        self._model = model
-
-        load_config = LoadConfig(load_format=self._load_format)
-        loader = get_model_loader(load_config)
-        assert isinstance(loader, DefaultModelLoader)
-        with set_default_torch_dtype(self._model_config.dtype):
-            weights = list(
-                loader._get_weights_iterator(DefaultModelLoader.Source.init_new(self._model_config, self._model)))
-
-    def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]:
-        return TODO
+def _get_all_weights_vanilla(load_format: str, model_config: ModelConfig, model):
+    load_config = LoadConfig(load_format=load_format)
+    loader = get_model_loader(load_config)
+    assert isinstance(loader, DefaultModelLoader)
+    with set_default_torch_dtype(model_config.dtype):
+        return list(loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model)))
 
 
 class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase):

From 649fe9dfd5e2cdb145e6c2d5b8cf0198e6c38124 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:26:58 +0800
Subject: [PATCH 1051/1089] more

---
 .../model_executor/model_weight_updater.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index e1543dbc5b8..0464fa8c6d6 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -26,9 +26,8 @@ def __init__(
         self._model = model
         self._device = device
 
-        ModelWeightSourceCls = _ModelWeightSourcePinnedMemory if init_pin_memory else _ModelWeightSourceVanilla
-        self._model_weight_source = ModelWeightSourceCls(load_format=load_format, model_config=model_config,
-                                                         model=model)
+        self._all_weights = _get_all_weights(load_format=load_format, model_config=model_config, model=model,
+                                             pin_memory=init_pin_memory)
         self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda()
 
         self._state: _State = _StateIdle()
@@ -87,21 +86,17 @@ class _StatePrepared(_State):
     named_tensors: List[Tuple[str, torch.Tensor]]
 
 
-def _get_all_weights_vanilla(load_format: str, model_config: ModelConfig, model):
+def _get_all_weights(load_format: str, model_config: ModelConfig, model, pin_memory: bool):
     load_config = LoadConfig(load_format=load_format)
     loader = get_model_loader(load_config)
     assert isinstance(loader, DefaultModelLoader)
     with set_default_torch_dtype(model_config.dtype):
-        return list(loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model)))
+        all_weights = list(loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model)))
 
+    if pin_memory:
+        all_weights = _named_tensors_pin_memory(all_weights)
 
-class _ModelWeightSourcePinnedMemory(_ModelWeightSourceBase):
-    def __init__(self, *args, **kwargs):
-        vanilla = _ModelWeightSourceVanilla(*args, **kwargs)
-        self._all_weights = _named_tensors_pin_memory(list(vanilla.get_all_weights()))
-
-    def get_all_weights(self) -> Iterable[Tuple[str, torch.Tensor]]:
-        return self._all_weights
+    return all_weights
 
 
 def _named_tensors_pin_memory(named_tensors: Iterable[Tuple[str, torch.Tensor]]) -> List[Tuple[str, torch.Tensor]]:

From c1224c05e1148d4a0aadef34c5b4cdb9df5ffac4 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 19:32:24 +0800
Subject: [PATCH 1052/1089] more

---
 python/sglang/srt/model_executor/model_weight_updater.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 0464fa8c6d6..b14fcb8f8bb 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -35,8 +35,7 @@ def __init__(
     def start_prepare(self, weight_filter: Callable[[str], bool]):
         assert isinstance(self._state, _StateIdle)
 
-        all_weights_iterator = self._model_weight_source.get_all_weights()
-        interesting_weights = [(name, weight) for name, weight in all_weights_iterator if weight_filter(name)]
+        interesting_weights = [(name, weight) for name, weight in self._all_weights if weight_filter(name)]
        self._memory_transfer_manager.enqueue(interesting_weights)
 
         self._state = _StateAwaitMemoryTransfer()
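Patches 1047-1052 are stepwise refactors, and the intermediate TODO placeholders (patch 1047's enable_mla_postprocess=TODO, patch 1049's return TODO) are each deliberately resolved by the commit that follows. What survives the refactor is the updater's three-phase protocol; a compact sketch of that cycle (illustrative names, not the repository's classes):

    from enum import Enum, auto

    class Phase(Enum):
        IDLE = auto()
        AWAIT_TRANSFER = auto()
        PREPARED = auto()

    class Updater:
        def __init__(self, all_weights):
            self._all_weights = all_weights      # materialized once, like _get_all_weights
            self._phase = Phase.IDLE

        def start_prepare(self, weight_filter):
            assert self._phase is Phase.IDLE
            self._staged = [(n, w) for n, w in self._all_weights if weight_filter(n)]
            self._phase = Phase.AWAIT_TRANSFER   # real code enqueues async H2D copies here

        def poll_prepare_end(self, transfer_done: bool) -> bool:
            if self._phase is Phase.AWAIT_TRANSFER and transfer_done:
                self._phase = Phase.PREPARED
            return self._phase is Phase.PREPARED

        def act(self):
            assert self._phase is Phase.PREPARED  # copy staged tensors into the live model
            self._phase = Phase.IDLE
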
From 91fd636a90ca1df05740a92c09b94171dd828d97 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:04:21 +0800
Subject: [PATCH 1053/1089] more

---
 .../srt/managers/expert_distribution.py | 26 +++++++------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index 89915f6169f..f755b5c699e 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -506,7 +506,8 @@ def dump(self):
         )
 
     def flush_buffer_depending_on_expert_location_metadata(self):
-        self._logical_count += _convert_global_physical_count_to_logical_count(self._buffer_global_physical_count, expert_location_metadata=self._expert_location_metadata)
+        self._logical_count += _convert_global_physical_count_to_logical_count(self._buffer_global_physical_count,
+                                                                               expert_location_metadata=self._expert_location_metadata)
         self._buffer_global_physical_count[...] = 0
 
 
@@ -514,20 +515,13 @@ def _convert_global_physical_count_to_logical_count(
     global_physical_count: torch.Tensor,
     expert_location_metadata: ExpertLocationMetadata,
 ):
-    logical_count = torch.zeros(
-        (
-            expert_location_metadata.num_layers,
-            expert_location_metadata.num_logical_experts,
-        )
+    num_layers = expert_location_metadata.num_layers
+    num_logical_experts = expert_location_metadata.num_logical_experts
+
+    logical_count = torch.zeros((num_layers, num_logical_experts))
+    logical_count.scatter_add_(
+        dim=1,
+        index=expert_location_metadata.physical_to_logical_map,
+        src=global_physical_count,
     )
-    # Most naive implementation, can optimize if it is bottleneck
-    for layer_index in range(expert_location_metadata.num_layers):
-        for global_physical_expert_index in range(expert_location_metadata.num_physical_experts):
-            logical_expert_index = (
-                expert_location_metadata.physical_to_logical_map[
-                    layer_index, global_physical_expert_index
-                ]
-            )
-            logical_count[layer_index, logical_expert_index] += global_physical_count[
-                layer_index, global_physical_expert_index]
     return logical_count
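Patch 1053 swaps the self-described "most naive implementation" double loop for one vectorized scatter_add_: along dim=1, each physical slot's count is accumulated into the logical expert that physical_to_logical_map assigns it. A self-contained equivalence check on toy shapes (note scatter_add_ needs an integer index tensor and matching dtypes for destination and src):

    import torch

    num_layers, num_physical, num_logical = 2, 4, 3
    phys2log = torch.tensor([[0, 1, 1, 2],
                             [2, 0, 1, 0]])                 # (layers, physical slots)
    phys_count = torch.arange(8, dtype=torch.float32).reshape(num_layers, num_physical)

    fast = torch.zeros(num_layers, num_logical)
    fast.scatter_add_(dim=1, index=phys2log, src=phys_count)

    slow = torch.zeros(num_layers, num_logical)
    for l in range(num_layers):
        for p in range(num_physical):
            slow[l, phys2log[l, p]] += phys_count[l, p]

    assert torch.equal(fast, slow)
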
From d9da5d2d7b6f3a25325818803f5135c3b9f0e25a Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:06:40 +0800
Subject: [PATCH 1054/1089] more

---
 .../model_executor/expert_location_updater.py  |  4 ++--
 .../srt/model_executor/model_weight_updater.py | 17 +++++++++++------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py
index 12123163385..ec1c469a104 100644
--- a/python/sglang/srt/model_executor/expert_location_updater.py
+++ b/python/sglang/srt/model_executor/expert_location_updater.py
@@ -91,8 +91,8 @@ def _act(self):
         _log_with_accurate_time("ExpertLocationUpdater.act end")
         return UpdateExpertLocationReqOutput()
 
-    def _weight_filter(self, name: str, interesting_logical_experts_of_layer: Dict[int, List[int]]):
-        info: ModelParamNameInfo = self._model_runner.model.get_param_name_info(name)
+    def _weight_filter(self, _name: str, info: ModelParamNameInfo,
+                       interesting_logical_experts_of_layer: Dict[int, List[int]]):
         return (
             isinstance(info, ModelParamNameInfoMoe)
             and (info.expert_id in interesting_logical_experts_of_layer.get(info.layer_id, []))
         )
diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index b14fcb8f8bb..ef7197b2e67 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -26,16 +26,16 @@ def __init__(
         self._model = model
         self._device = device
 
-        self._all_weights = _get_all_weights(load_format=load_format, model_config=model_config, model=model,
-                                             pin_memory=init_pin_memory)
+        self._all_weights_and_info = _get_all_weights_and_info(load_format=load_format, model_config=model_config,
+                                                               model=model,
+                                                               pin_memory=init_pin_memory)
         self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda()
 
         self._state: _State = _StateIdle()
 
-    def start_prepare(self, weight_filter: Callable[[str], bool]):
+    def start_prepare(self, weight_filter):
         assert isinstance(self._state, _StateIdle)
 
-        interesting_weights = [(name, weight) for name, weight in self._all_weights if weight_filter(name)]
+        interesting_weights = [(name, weight) for name, weight, info in self._all_weights_and_info if weight_filter(name, info)]
         self._memory_transfer_manager.enqueue(interesting_weights)
 
         self._state = _StateAwaitMemoryTransfer()
@@ -85,7 +85,7 @@ class _StatePrepared(_State):
     named_tensors: List[Tuple[str, torch.Tensor]]
 
 
-def _get_all_weights(load_format: str, model_config: ModelConfig, model, pin_memory: bool):
+def _get_all_weights_and_info(load_format: str, model_config: ModelConfig, model, pin_memory: bool):
     load_config = LoadConfig(load_format=load_format)
     loader = get_model_loader(load_config)
     assert isinstance(loader, DefaultModelLoader)
@@ -95,7 +95,12 @@ def _get_all_weights_and_info(load_format: str, model_config: ModelConfig, model, pin_mem
     if pin_memory:
         all_weights = _named_tensors_pin_memory(all_weights)
 
-    return all_weights
+    all_weights_and_info = [
+        (name, weight, model.get_param_name_info(name))
+        for name, weight in all_weights
+    ]
+
+    return all_weights_and_info

From 5d82a94d6b9c8ec87c90a00bd73479cad4865b6e Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:10:54 +0800
Subject: [PATCH 1055/1089] more

---
 python/sglang/srt/model_executor/expert_location_updater.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py
index ec1c469a104..0b3f9024b57 100644
--- a/python/sglang/srt/model_executor/expert_location_updater.py
+++ b/python/sglang/srt/model_executor/expert_location_updater.py
@@ -49,7 +49,7 @@ def start(self, req: UpdateExpertLocationReqInput):
         )
 
         self._model_weight_updater.start_prepare(
-            weight_filter=lambda name: self._weight_filter(name, interesting_logical_experts_of_layer),
+            weight_filter=lambda name, info: self._weight_filter(info, interesting_logical_experts_of_layer),
         )
         _log_with_accurate_time("ExpertLocationUpdater.start end")
 
@@ -91,7 +91,7 @@ def _act(self):
         _log_with_accurate_time("ExpertLocationUpdater.act end")
         return UpdateExpertLocationReqOutput()
 
-    def _weight_filter(self, _name: str, info: ModelParamNameInfo,
+    def _weight_filter(self, info: ModelParamNameInfo,
                        interesting_logical_experts_of_layer: Dict[int, List[int]]):
         return (
             isinstance(info, ModelParamNameInfoMoe)

From 4c77a3ef1859fa5be8dff776ab7e2c393b9494d8 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:16:23 +0800
Subject: [PATCH 1056/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 777c606d4a1..0282f79ba13 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -113,7 +113,7 @@ def get_outputs(self) -> List[NamedTensors]:
         return outputs
 
     def _handle_one_output(self, task: "_AsyncToCudaTask"):
-        torch.cuda.current_stream().wait_stream(self._alt_stream)
+        task.finish_event.wait(torch.cuda.current_stream())
         return task.output_named_tensors
 
     def _auto_create_stream(self):
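Patch 1056 tightens synchronization: current_stream().wait_stream(alt_stream) orders against everything enqueued on the side stream so far, including copies that belong to other still-inflight tasks, whereas waiting on the task's own finish_event orders against exactly that task's copy. The pattern in isolation (a sketch assuming a CUDA device; not the repository code):

    import torch

    copy_stream = torch.cuda.Stream()
    host_t = torch.randn(1024, pin_memory=True)

    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        dev_t = host_t.to("cuda", non_blocking=True)
        finish_event = torch.cuda.Event()
        finish_event.record(copy_stream)

    # Consumer: block the current stream only on this one copy.
    finish_event.wait(torch.cuda.current_stream())
    result = dev_t * 2
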
From c49f0280bff853d6cafe01aff45f3753ca03a628 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:27:22 +0800
Subject: [PATCH 1057/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 0282f79ba13..d904d9abf74 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -126,3 +126,6 @@ class _AsyncToCudaTask:
     finish_event: torch.cuda.Event
     input_named_tensors: NamedTensors
     output_named_tensors: NamedTensors
+
+class SimpleCachingAllocator:
+    TODO

From a111bff525700328b6322eaf6e29aac7125362d6 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:27:46 +0800
Subject: [PATCH 1058/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index d904d9abf74..b8eeecb782b 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -127,5 +127,13 @@ class _AsyncToCudaTask:
     input_named_tensors: NamedTensors
     output_named_tensors: NamedTensors
 
+
 class SimpleCachingAllocator:
-    TODO
+    def __init__(self):
+        TODO
+
+    def allocate(self) -> torch.Tensor:
+        return TODO
+
+    def mark_all_unused(self):
+        TODO

From 75d36e75cc7341ca7d4f4c315c595320c8036c84 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:28:23 +0800
Subject: [PATCH 1059/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index b8eeecb782b..682f942e86f 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -1,6 +1,7 @@
 import logging
 import queue
 import traceback
+from collections import defaultdict
 from dataclasses import dataclass
 from queue import SimpleQueue
 from threading import Thread
@@ -130,7 +131,9 @@ class _AsyncToCudaTask:
 
 class SimpleCachingAllocator:
     def __init__(self):
-        TODO
+        # (size, dtype) -> list[Tensor]
+        self._unused_pool = defaultdict(list)
+        self._used_pool: List[torch.Tensor] = []
 
     def allocate(self) -> torch.Tensor:
         return TODO

From 50e4540e8b91b48b4d4622f7c394f2bc3fd584a1 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:28:53 +0800
Subject: [PATCH 1060/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 682f942e86f..6de74961cdf 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -135,7 +135,12 @@ def __init__(self):
         self._unused_pool = defaultdict(list)
         self._used_pool: List[torch.Tensor] = []
 
-    def allocate(self) -> torch.Tensor:
+    def allocate(self, size, dtype) -> torch.Tensor:
+        unused_pool_entry = self._unused_pool[(size, dtype)]
+        if len(unused_pool_entry) > 0:
+            TODO
+        else:
+            TODO
         return TODO
 
     def mark_all_unused(self):

From 0f24d392ef7b74db53afaf424b591c23c5386bd9 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:29:31 +0800
Subject: [PATCH 1061/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 6de74961cdf..f75d0337070 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -130,7 +130,9 @@ class _AsyncToCudaTask:
 
 class SimpleCachingAllocator:
-    def __init__(self):
+    def __init__(self, device):
+        self._device = device
+
         # (size, dtype) -> list[Tensor]
         self._unused_pool = defaultdict(list)
         self._used_pool: List[torch.Tensor] = []
@@ -138,9 +140,9 @@ def __init__(self):
     def allocate(self, size, dtype) -> torch.Tensor:
         unused_pool_entry = self._unused_pool[(size, dtype)]
         if len(unused_pool_entry) > 0:
-            TODO
+            output = unused_pool_entry.pop()
         else:
-            TODO
+            output = torch.empty(size, dtype=dtype, device=self._device)
         return TODO
 
     def mark_all_unused(self):

From 3a73d7f481a67499de86891e187b5a35cf5fc673 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:29:49 +0800
Subject: [PATCH 1062/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index f75d0337070..8b36545f9b4 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -143,7 +143,10 @@ def allocate(self, size, dtype) -> torch.Tensor:
             output = unused_pool_entry.pop()
         else:
             output = torch.empty(size, dtype=dtype, device=self._device)
-        return TODO
+
+        self._used_pool.append(output)
+
+        return output
 
     def mark_all_unused(self):
         TODO

From 770fbffd2ad5e71b41ac7e447ce92fa48b5e8030 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:30:30 +0800
Subject: [PATCH 1063/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 8b36545f9b4..2c1d5f994bc 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -149,4 +149,7 @@ def allocate(self, size, dtype) -> torch.Tensor:
         return output
 
     def mark_all_unused(self):
-        TODO
+        for tensor in self._used_pool:
+            self._unused_pool[(tensor.size, tensor.dtype)].append(tensor)
+
+        self._used_pool.clear()
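By patch 1063 the allocator is complete: allocate pops a recycled buffer keyed by (size, dtype) or creates a fresh one, and mark_all_unused returns everything to the pool between update rounds, avoiding repeated cudaMalloc/cudaFree. One wrinkle worth flagging: the commit keys recycling on tensor.size, which on a torch.Tensor is a bound method rather than a shape value, so as written the recycled entries would not match a later allocate key; a shape-tuple key behaves as intended. A sketch with that adjustment (illustrative, not the repository code):

    from collections import defaultdict
    import torch

    class CachingAllocator:
        def __init__(self, device="cpu"):
            self._device = device
            self._unused = defaultdict(list)   # (shape, dtype) -> [Tensor]
            self._used = []

        def allocate(self, shape, dtype):
            entry = self._unused[(tuple(shape), dtype)]
            out = entry.pop() if entry else torch.empty(*shape, dtype=dtype, device=self._device)
            self._used.append(out)
            return out

        def mark_all_unused(self):
            for t in self._used:
                self._unused[(tuple(t.shape), t.dtype)].append(t)
            self._used.clear()

    alloc = CachingAllocator()
    a = alloc.allocate((4, 4), torch.float32)
    alloc.mark_all_unused()
    b = alloc.allocate((4, 4), torch.float32)
    assert a is b   # the buffer was recycled, not reallocated
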
From e298c9b3f33855c4f7ec20d76d7e503b4d1b72a4 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:31:42 +0800
Subject: [PATCH 1064/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 2c1d5f994bc..21c684f0e14 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -84,7 +84,8 @@ def _background_thread_entrypoint(self):
 
 # Can use cuMemCreate etc if we want to further remove a GPU->GPU copy
 class AsyncToCudaManager(TensorOperationManagerBase):
-    def __init__(self):
+    def __init__(self, allocator: "SimpleCachingAllocator"):
+        self._allocator = allocator
         self._inflight_tasks: List[_AsyncToCudaTask] = []
         self._alt_stream: Optional[torch.cuda.Stream] = None
 
@@ -121,6 +122,12 @@ def _auto_create_stream(self):
         if self._alt_stream is None:
             self._alt_stream = torch.cuda.Stream()
 
+    @staticmethod
+    def _tensor_to_cuda(input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator"):
+        output_tensor = allocator.allocate(input_tensor.size, input_tensor.dtype)
+        output_tensor.copy_(input_tensor, non_blocking=True)
+        return output_tensor
+
 
 @dataclass
 class _AsyncToCudaTask:

From 169ddc57e484f4a57deaad53f36eddef090dfd1c Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:31:54 +0800
Subject: [PATCH 1065/1089] more

---
 python/sglang/srt/model_executor/memory_transfer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index 21c684f0e14..e04703a1bab 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -95,7 +95,7 @@ def enqueue(self, named_tensors: NamedTensors):
         self._alt_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._alt_stream):
             output_named_tensors = [
-                (name, tensor.to("cuda", non_blocking=True))
+                (name, self._tensor_to_cuda(tensor, self._allocator))
                 for name, tensor in named_tensors
             ]
             finish_event = torch.cuda.Event()

From 88c40a723cb13af7eaf69791b5c9f297afb1991b Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:33:04 +0800
Subject: [PATCH 1066/1089] more

---
 .../sglang/srt/model_executor/memory_transfer.py |  4 ++--
 .../srt/model_executor/model_weight_updater.py   | 16 +++++++++++-----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py
index e04703a1bab..dbc83acfcde 100644
--- a/python/sglang/srt/model_executor/memory_transfer.py
+++ b/python/sglang/srt/model_executor/memory_transfer.py
@@ -30,8 +30,8 @@ def __init__(self, manager_a: TensorOperationManagerBase, manager_b: TensorOpera
         self._manager_b = manager_b
 
     @classmethod
-    def init_pin_memory_and_to_cuda(cls):
-        return cls(manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager())
+    def init_pin_memory_and_to_cuda(cls, allocator: "SimpleCachingAllocator"):
+        return cls(manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager(allocator))
 
     def enqueue(self, named_tensors: NamedTensors):
         self._manager_a.enqueue(named_tensors)
diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index ef7197b2e67..23c25f99ca0 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -1,12 +1,12 @@
 import logging
 from abc import ABC
 from dataclasses import dataclass
-from typing import Tuple, List, Callable, Iterable
+from typing import Tuple, List, Iterable
 
 import torch
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager
+from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager, SimpleCachingAllocator
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 
@@ -26,16 +26,22 @@ def __init__(
         self._model = model
         self._device = device
 
-        self._all_weights_and_info = _get_all_weights_and_info(load_format=load_format, model_config=model_config,
-                                                               model=model,
-                                                               pin_memory=init_pin_memory)
+        self._all_weights_and_info = _get_all_weights_and_info(load_format=load_format, model_config=model_config,
+                                                               model=model,
+                                                               pin_memory=init_pin_memory)
-        self._memory_transfer_manager = AsyncToCudaManager() if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda()
+        self._transfer_allocator = SimpleCachingAllocator()
+        self._memory_transfer_manager = AsyncToCudaManager(
+            self._transfer_allocator) if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda(
+            self._transfer_allocator)
 
         self._state: _State = _StateIdle()
 
     def start_prepare(self, weight_filter):
         assert isinstance(self._state, _StateIdle)
 
-        interesting_weights = [(name, weight) for name, weight, info in self._all_weights_and_info if weight_filter(name, info)]
+        self._transfer_allocator.mark_all_unused()
+        interesting_weights = [(name, weight) for name, weight, info in self._all_weights_and_info if
+                               weight_filter(name, info)]
         self._memory_transfer_manager.enqueue(interesting_weights)
 
         self._state = _StateAwaitMemoryTransfer()

From b6a846730ce9a21ae7709d91515f91d01aa5b4dc Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Thu, 17 Apr 2025 20:33:24 +0800
Subject: [PATCH 1067/1089] more

---
 python/sglang/srt/model_executor/model_weight_updater.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py
index 23c25f99ca0..798228b4053 100644
--- a/python/sglang/srt/model_executor/model_weight_updater.py
+++ b/python/sglang/srt/model_executor/model_weight_updater.py
@@ -29,7 +29,7 @@ def __init__(
         self._all_weights_and_info = _get_all_weights_and_info(load_format=load_format, model_config=model_config,
                                                                model=model,
                                                                pin_memory=init_pin_memory)
-        self._transfer_allocator = SimpleCachingAllocator()
+        self._transfer_allocator = SimpleCachingAllocator(device="cuda")
         self._memory_transfer_manager = AsyncToCudaManager(
            self._transfer_allocator) if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda(
            self._transfer_allocator)
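Taken together, patches 1064-1067 wire the allocator into the transfer path: _tensor_to_cuda copies each pinned host tensor into an allocator-served device buffer on the side stream, CombinedManager plumbs the allocator through the pin-memory + to-cuda pipeline, and start_prepare recycles all buffers from the previous round before enqueueing the next. End to end, the staging step looks roughly like this (a sketch assuming a CUDA device and the shape-keyed allocator sketched above; names are illustrative):

    import torch

    def stage_to_gpu(named_tensors, allocator, copy_stream):
        """Async H2D copies on a side stream; returns staged tensors plus a finish event."""
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            outputs = []
            for name, host_t in named_tensors:
                dev_t = allocator.allocate(tuple(host_t.shape), host_t.dtype)
                dev_t.copy_(host_t, non_blocking=True)   # safe: host_t is pinned
                outputs.append((name, dev_t))
            finish_event = torch.cuda.Event()
            finish_event.record(copy_stream)
        return outputs, finish_event

    # Consumer side: block the current stream only on this batch's copies.
    # outputs, done = stage_to_gpu(weights, alloc, torch.cuda.Stream())
    # done.wait(torch.cuda.current_stream())
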
import ModelConfig from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_location import ExpertLocationMetadata @@ -659,10 +660,14 @@ def _compute_initial_expert_location_metadata( data_dict = json.loads(Path(data).read_text()) if "physical_to_logical_map" in data_dict: - logger.info("init_expert_location from init_by_mapping using ServerArgs.init_expert_location") + logger.info( + "init_expert_location from init_by_mapping using ServerArgs.init_expert_location" + ) return ExpertLocationMetadata.init_by_mapping(server_args, **data_dict) elif "logical_count" in data_dict: - logger.info("init_expert_location from init_by_eplb using ServerArgs.init_expert_location") + logger.info( + "init_expert_location from init_by_eplb using ServerArgs.init_expert_location" + ) return ExpertLocationMetadata.init_by_eplb(server_args, **data_dict) else: raise NotImplementedError( diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index d1a2dee0ade..fde85d09f68 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -4,6 +4,7 @@ import torch import triton import triton.language as tl + from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import ( @@ -443,13 +444,13 @@ def gelu_and_mul_triton_kernel( * ( 1 + tanh( - kAlpha - * ( - gate_output - + 0.044715 * gate_output * gate_output * gate_output + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output + ) ) ) - ) ) gate_output = gate_output.to(InDtype) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index d049d36df69..bf270e2a6e6 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Tuple import torch + from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata @@ -277,7 +278,7 @@ def forward( BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -480,7 +481,7 @@ def _weight_loader_physical( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -512,11 +513,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index ca1cf67bdb4..5082c663d6a 100644 --- 
a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,4 +1,6 @@ -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.utils import DeepEPMode, DisposibleTensor try: @@ -426,7 +428,9 @@ def dispatch_b( ): hook() if self.return_recv_hook else event.current_stream_wait() - get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(masked_m) + get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency( + masked_m + ) reorder_topk_ids = seg_indptr = None diff --git a/python/sglang/srt/layers/moe/expert_location_dispatch.py b/python/sglang/srt/layers/moe/expert_location_dispatch.py index 78b59438b31..8fcab2af871 100644 --- a/python/sglang/srt/layers/moe/expert_location_dispatch.py +++ b/python/sglang/srt/layers/moe/expert_location_dispatch.py @@ -2,7 +2,11 @@ from typing import Literal, Optional import torch -from sglang.srt.managers.schedule_batch import global_server_args_dict, get_global_expert_location_metadata + +from sglang.srt.managers.schedule_batch import ( + get_global_expert_location_metadata, + global_server_args_dict, +) @dataclass @@ -23,14 +27,20 @@ def init_new(cls, ep_rank: int, layer_id: int): return cls( ep_dispatch_algorithm=ep_dispatch_algorithm, partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ - ep_rank, layer_id, :], - partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[layer_id, :], + ep_rank, layer_id, : + ], + partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[ + layer_id, : + ], partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[ - layer_id, :], + layer_id, : + ], ) -def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: +def topk_ids_logical_to_physical( + topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: if info is None: return topk_ids @@ -41,19 +51,23 @@ def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: Optional[ExpertLo raise NotImplementedError -def _topk_ids_logical_to_physical_static(topk_ids: torch.Tensor, - info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: +def _topk_ids_logical_to_physical_static( + topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: return info.partial_logical_to_rank_dispatch_physical_map[topk_ids] -def _topk_ids_logical_to_physical_random(topk_ids: torch.Tensor, - info: Optional[ExpertLocationDispatchInfo]) -> torch.Tensor: +def _topk_ids_logical_to_physical_random( + topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: topk_ids_original_shape = topk_ids.shape device = topk_ids.device topk_ids = topk_ids.flatten() - chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) - % info.partial_logical_to_all_physical_map_num_valid[topk_ids]) + chosen_dispatch_index = ( + torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) + % info.partial_logical_to_all_physical_map_num_valid[topk_ids] + ) topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] topk_ids = topk_ids.view(topk_ids_original_shape) diff --git 
a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b954207319b..2bfde25d783 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -16,10 +16,16 @@ import torch import torch.nn.functional as F -from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo, topk_ids_logical_to_physical -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder + +from sglang.srt.layers.moe.expert_location_dispatch import ( + ExpertLocationDispatchInfo, + topk_ids_logical_to_physical, +) +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip, get_bool_env_var +from sglang.srt.utils import get_bool_env_var, get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() _is_hip = is_hip() diff --git a/python/sglang/srt/managers/eplb_simulator.py b/python/sglang/srt/managers/eplb_simulator.py index d225bb9802b..fce7912e048 100644 --- a/python/sglang/srt/managers/eplb_simulator.py +++ b/python/sglang/srt/managers/eplb_simulator.py @@ -10,12 +10,13 @@ import einops import polars as pl import torch +from tqdm.auto import tqdm + from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_location import ( ExpertLocationMetadata, ModelConfigForExpertLocation, ) -from tqdm.auto import tqdm @dataclass @@ -36,9 +37,9 @@ class MyExpertLocationMetadata: @staticmethod def init_by_eplb( - server_args: MyServerArgs, - logical_count: torch.Tensor, - num_physical_experts: int, + server_args: MyServerArgs, + logical_count: torch.Tensor, + num_physical_experts: int, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION @@ -72,10 +73,10 @@ def read_physical_count_of_forward_pass(dir_data: Path): for path in tqdm(list(dir_data.glob("*.pt"))): for record in torch.load(path, weights_only=True): assert ( - physical_count_of_forward_pass_id_and_rank[ - record["forward_pass_id"] - ].get(record["rank"]) - is None + physical_count_of_forward_pass_id_and_rank[ + record["forward_pass_id"] + ].get(record["rank"]) + is None ) physical_count_of_forward_pass_id_and_rank[record["forward_pass_id"]][ record["rank"] @@ -84,10 +85,14 @@ def read_physical_count_of_forward_pass(dir_data: Path): items = [] for forward_pass_id, physical_count_of_rank in sorted( - physical_count_of_forward_pass_id_and_rank.items() + physical_count_of_forward_pass_id_and_rank.items() ): physical_count_of_rank_tensor = torch.stack( - [physical_count for rank, physical_count in sorted(physical_count_of_rank.items())]).sum(dim=0) + [ + physical_count + for rank, physical_count in sorted(physical_count_of_rank.items()) + ] + ).sum(dim=0) items.append(physical_count_of_rank_tensor) physical_count_of_forward_pass = torch.stack(items) @@ -97,24 +102,24 @@ def read_physical_count_of_forward_pass(dir_data: Path): def scan_combinations( - logical_count_of_seq: torch.Tensor, - override_eplb_input_logical_count: Optional[torch.Tensor] = None, + logical_count_of_seq: torch.Tensor, + override_eplb_input_logical_count: Optional[torch.Tensor] = None, ): num_gpu_per_node = 8 server_args_list = [ *[ MyServerArgs( - num_tokens_in_batch_overall=num_tokens_in_batch_per_gpu * num_gpu_per_node * nnodes, + num_tokens_in_batch_overall=num_tokens_in_batch_per_gpu + * num_gpu_per_node + * nnodes, ep_num_redundant_experts=ep_num_redundant_experts, 
nnodes=nnodes, tp_size=num_gpu_per_node * nnodes, enable_expert_location_by_eplb=enable_expert_location_by_eplb, init_expert_location=init_expert_location, ) - # for init_expert_location in ["/host_home/temp_sglang_server2local/1744461420780309768.json", None] for init_expert_location in ["from_variable"] - # decode # for ep_num_redundant_experts in [0, 32] # for nnodes in [ @@ -125,7 +130,6 @@ def scan_combinations( # *([9] if ep_num_redundant_experts == 32 else []), # ] # for num_tokens_in_batch_per_gpu in [64, 128] - # prefill for ep_num_redundant_experts in [0, 32] for nnodes in [4] @@ -133,7 +137,6 @@ def scan_combinations( # for ep_num_redundant_experts in [0, 32, 64] # for nnodes in [1, 2, 4] # for num_tokens_in_batch_per_gpu in [1024, 4096, 8192, 16384] - for enable_expert_location_by_eplb in [ *([False] if ep_num_redundant_experts == 0 else []), True, @@ -145,7 +148,8 @@ def scan_combinations( for server_args in server_args_list: print() info = simulate_execution( - logical_count_of_seq=logical_count_of_seq, server_args=server_args, + logical_count_of_seq=logical_count_of_seq, + server_args=server_args, override_eplb_input_logical_count=override_eplb_input_logical_count, ) print(f"{server_args=} {info=}") @@ -170,9 +174,9 @@ def analyze_actual_utilization_rate(dir_data: Path, num_gpu: int): def simulate_execution( - logical_count_of_seq: torch.Tensor, - server_args: MyServerArgs, - override_eplb_input_logical_count: Optional[torch.Tensor] = None, + logical_count_of_seq: torch.Tensor, + server_args: MyServerArgs, + override_eplb_input_logical_count: Optional[torch.Tensor] = None, ): model_config_for_expert_location = _MY_MODEL_CONFIG_FOR_EXPERT_LOCATION @@ -184,20 +188,26 @@ def simulate_execution( if server_args.enable_expert_location_by_eplb: num_physical_expert = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) if server_args.init_expert_location == "from_variable": - print(f"Compute eplb_input_logical_count from override_eplb_input_logical_count") + print( + f"Compute eplb_input_logical_count from override_eplb_input_logical_count" + ) eplb_input_logical_count = override_eplb_input_logical_count elif (x := server_args.init_expert_location) is not None: print(f"Compute eplb_input_logical_count from {x}") - eplb_input_logical_count = torch.tensor(json.loads(Path(x).read_text())["logical_count"]) + eplb_input_logical_count = torch.tensor( + json.loads(Path(x).read_text())["logical_count"] + ) else: print(f"Compute eplb_input_logical_count from logical_count_of_seq") - eplb_input_logical_count = einops.einsum(logical_count_of_seq, - "num_seq num_layer num_expert -> num_layer num_expert", ) + eplb_input_logical_count = einops.einsum( + logical_count_of_seq, + "num_seq num_layer num_expert -> num_layer num_expert", + ) expert_location_metadata = MyExpertLocationMetadata.init_by_eplb( server_args, @@ -235,8 +245,8 @@ def simulate_execution( def simulate_batching( - logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) - num_tokens_in_batch_overall: int, + logical_count_of_seq: torch.Tensor, # (num_seq, num_layer, num_logical_expert) + num_tokens_in_batch_overall: int, ) -> torch.Tensor: """output: (num_batch, num_layer, num_logical_expert)""" tensor_chunks = chunker( @@ -250,9 +260,9 @@ def simulate_batching( def simulate_logical_to_physical( - logical_count_of_whatever: torch.Tensor, # (*, num_layer, 
num_logical_expert) - logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) - num_physical_expert: int, + logical_count_of_whatever: torch.Tensor, # (*, num_layer, num_logical_expert) + logical_to_all_physical_map: torch.Tensor, # (num_layer, num_logical_experts, X) + num_physical_expert: int, ): """output: (*, num_layer, num_physical_expert)""" num_whatever, num_layer, num_logical_expert = logical_count_of_whatever.shape @@ -271,7 +281,7 @@ def simulate_logical_to_physical( ) for physical_expert_id in all_physical_expert_ids: physical_count_of_whatever[ - :, layer_id, physical_expert_id + :, layer_id, physical_expert_id ] += logical_count_of_whatever[:, layer_id, logical_expert_id] / len( all_physical_expert_ids ) @@ -280,8 +290,8 @@ def simulate_logical_to_physical( def compute_gpu_physical_count( - physical_count_of_whatever: torch.Tensor, # (whatever, num_layer, num_physical_expert) - num_gpu: int, + physical_count_of_whatever: torch.Tensor, # (whatever, num_layer, num_physical_expert) + num_gpu: int, ): """output: gpu_physical_count_of_batch (whatever, num_layer, num_gpu)""" return einops.reduce( @@ -293,7 +303,7 @@ def compute_gpu_physical_count( def compute_utilization_rate( - gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) + gpu_physical_count_of_batch: torch.Tensor, # (num_batch, num_layer, num_gpu) ): """output: utilization_rate (num_batch, num_layer)""" gpu_physical_count_of_batch = gpu_physical_count_of_batch.float() @@ -311,7 +321,9 @@ def compute_utilization_rate( def compute_num_token(whatever_with_num_layer_and_num_expert: torch.Tensor): - num_token_mul_num_experts = whatever_with_num_layer_and_num_expert[..., -1, :].sum(dim=-1) + num_token_mul_num_experts = whatever_with_num_layer_and_num_expert[..., -1, :].sum( + dim=-1 + ) return num_token_mul_num_experts / _MY_MODEL_CONFIG_NUM_EXPERTS_PER_TOK diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index f755b5c699e..55a6c6069eb 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -8,6 +8,7 @@ from typing import Any, List, Optional, Type import torch + from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -23,12 +24,14 @@ class ExpertDistributionRecorder: @staticmethod def init_new( - server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", - rank: int, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, ): if server_args.enable_expert_distribution_recorder: - return _ExpertDistributionRecorderReal(server_args, expert_location_metadata, rank) + return _ExpertDistributionRecorderReal( + server_args, expert_location_metadata, rank + ) else: return _ExpertDistributionRecorderNoop() @@ -50,7 +53,9 @@ def on_select_experts(self, topk_ids: torch.Tensor): def on_deepep_dispatch_normal(self, local_physical_count_of_layer: List[int]): pass - def on_deepep_dispatch_low_latency(self, local_physical_count_of_layer: torch.Tensor): + def on_deepep_dispatch_low_latency( + self, local_physical_count_of_layer: torch.Tensor + ): pass def start_record(self): @@ -64,7 +69,8 @@ def dump_record(self): def _on_not_implemented(self): raise Exception( - "Please enable ServerArgs.enable_expert_distribution_recorder to use ExpertDistributionRecorder.") + "Please enable 
ServerArgs.enable_expert_distribution_recorder to use ExpertDistributionRecorder." + ) class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder): @@ -73,10 +79,10 @@ class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder): class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): def __init__( - self, - server_args: ServerArgs, - expert_location_metadata: "ExpertLocationMetadata", - rank: int, + self, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, ): self._server_args = server_args self._expert_location_metadata = expert_location_metadata @@ -131,8 +137,13 @@ def on_deepep_dispatch_normal(self, local_physical_count_of_layer: List[int]): local_physical_count_of_layer=local_physical_count_of_layer, ) - def on_deepep_dispatch_low_latency(self, local_physical_count_of_layer: torch.Tensor): - self._on_hook("on_deepep_dispatch_low_latency", local_physical_count_of_layer=local_physical_count_of_layer) + def on_deepep_dispatch_low_latency( + self, local_physical_count_of_layer: torch.Tensor + ): + self._on_hook( + "on_deepep_dispatch_low_latency", + local_physical_count_of_layer=local_physical_count_of_layer, + ) def _on_hook(self, hook_name: str, **kwargs): if not (self._recording or torch.cuda.is_current_stream_capturing()): @@ -148,7 +159,7 @@ def _reset(self): """Reset the expert distribution recorder.""" logger.info("Resetting ExpertDistributionRecorder...") assert ( - self._current_layer_idx.value is None + self._current_layer_idx.value is None ), f"{self._current_layer_idx.value=}" for gatherer in self._single_pass_gatherers.values(): gatherer.reset() @@ -195,11 +206,9 @@ def set_global_expert_distribution_recorder(value): def postprocess_dumps( - dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" + dumps: List[Any], expert_location_metadata: "ExpertLocationMetadata" ): - return _Accumulator.get_class().postprocess_dumps( - dumps, expert_location_metadata - ) + return _Accumulator.get_class().postprocess_dumps(dumps, expert_location_metadata) # --------------------------------------- SinglePassGatherer ----------------------------------------- @@ -208,14 +217,18 @@ def postprocess_dumps( class _SinglePassGatherer(ABC): @staticmethod def init_new( - server_args: ServerArgs, expert_location_metadata: "ExpertLocationMetadata", rank: int, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, ) -> "_SinglePassGatherer": if server_args.enable_deepep_moe: # `auto` has many restrictions now, so we lower the priority to implement low-latency capturing for auto if server_args.deepep_mode in ["normal", "auto"]: return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) elif server_args.deepep_mode == "low_latency": - return _DeepepLowLatencySinglePassGatherer(expert_location_metadata, rank) + return _DeepepLowLatencySinglePassGatherer( + expert_location_metadata, rank + ) else: raise NotImplementedError return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) @@ -228,11 +241,13 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass def on_deepep_dispatch_normal( - self, layer_idx: int, local_physical_count_of_layer: List[int] + self, layer_idx: int, local_physical_count_of_layer: List[int] ): pass - def on_deepep_dispatch_low_latency(self, layer_idx: int, local_physical_count_of_layer: torch.Tensor): + def on_deepep_dispatch_low_latency( + self, layer_idx: int, local_physical_count_of_layer: torch.Tensor + ): pass def 
reset(self): @@ -247,9 +262,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._objects_of_layer = {} - def _on_layer_data( - self, layer_idx: int, objects: List[int] - ): + def _on_layer_data(self, layer_idx: int, objects: List[int]): assert layer_idx not in self._objects_of_layer assert 0 <= layer_idx < self._expert_location_metadata.num_layers self._objects_of_layer[layer_idx] = objects @@ -259,8 +272,7 @@ def reset(self): def _collect_objects(self, pad_len: int) -> torch.Tensor: data = [ - self._objects_of_layer.get(layer_index) - or ([0] * pad_len) + self._objects_of_layer.get(layer_index) or ([0] * pad_len) for layer_index in range(self._expert_location_metadata.num_layers) ] return torch.tensor(data) @@ -272,7 +284,9 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() torch.cuda.synchronize() - global_physical_count = [0] * self._expert_location_metadata.num_physical_experts + global_physical_count = [ + 0 + ] * self._expert_location_metadata.num_physical_experts for token_record in topk_ids_list: for global_physical_expert_idx in token_record: global_physical_count[global_physical_expert_idx] += 1 @@ -280,19 +294,22 @@ def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): self._on_layer_data(layer_idx, global_physical_count) def collect_global_physical_count(self) -> torch.Tensor: - return super()._collect_objects(pad_len=self._expert_location_metadata.num_physical_experts) + return super()._collect_objects( + pad_len=self._expert_location_metadata.num_physical_experts + ) class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer): def on_deepep_dispatch_normal( - self, layer_idx: int, local_physical_count_of_layer: List[int] + self, layer_idx: int, local_physical_count_of_layer: List[int] ): assert isinstance(local_physical_count_of_layer, list) self._on_layer_data(layer_idx, local_physical_count_of_layer) def collect_global_physical_count(self) -> torch.Tensor: local_physical_count = super()._collect_objects( - pad_len=self._expert_location_metadata.num_local_physical_experts) + pad_len=self._expert_location_metadata.num_local_physical_experts + ) return _convert_local_to_global_physical_count( local_physical_count, rank=self._rank, @@ -305,12 +322,17 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._data = torch.zeros( - (self._expert_location_metadata.num_layers, self._expert_location_metadata.num_local_physical_experts), + ( + self._expert_location_metadata.num_layers, + self._expert_location_metadata.num_local_physical_experts, + ), dtype=torch.int, device="cuda", ) - def on_deepep_dispatch_low_latency(self, layer_idx: int, local_physical_count_of_layer: torch.Tensor): + def on_deepep_dispatch_low_latency( + self, layer_idx: int, local_physical_count_of_layer: torch.Tensor + ): # Most naive implementation, can optimize later self._data[layer_idx, :] = local_physical_count_of_layer @@ -328,17 +350,19 @@ def collect_global_physical_count(self) -> torch.Tensor: def _convert_local_to_global_physical_count( - local_physical_count: torch.Tensor, - rank: int, - num_local_physical_experts: int, - num_physical_experts: int, + local_physical_count: torch.Tensor, + rank: int, + num_local_physical_experts: int, + num_physical_experts: int, ) -> torch.Tensor: dtype = local_physical_count.dtype device = local_physical_count.device num_layers, _ = 
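[The select-experts gatherer above builds a per-layer histogram over topk_ids with nested loops; the same count can be reproduced standalone, and vectorized, with torch.bincount. Shapes here are illustrative:]

    import torch

    def count_global_physical(topk_ids: torch.Tensor, num_physical_experts: int) -> torch.Tensor:
        # topk_ids: (num_tokens, top_k) physical-expert ids chosen by the router.
        return torch.bincount(topk_ids.flatten(), minlength=num_physical_experts)

    topk_ids = torch.tensor([[0, 2], [2, 3]])
    print(count_global_physical(topk_ids, num_physical_experts=5))  # tensor([1, 0, 2, 1, 0])
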
local_physical_count.shape ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device) - ans[:, num_local_physical_experts * rank:num_local_physical_experts * (rank + 1)] = local_physical_count + ans[ + :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1) + ] = local_physical_count return ans @@ -350,7 +374,7 @@ def _convert_local_to_global_physical_count( class _Accumulator(ABC): @staticmethod def init_new( - expert_location_metadata: "ExpertLocationMetadata", rank: int + expert_location_metadata: "ExpertLocationMetadata", rank: int ) -> "_Accumulator": return _Accumulator.get_class()(expert_location_metadata, rank) @@ -372,17 +396,17 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): @classmethod def postprocess_dumps( - cls, - dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): raise NotImplementedError def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_global_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_global_physical_count: torch.Tensor, ): raise NotImplementedError @@ -399,9 +423,9 @@ def flush_buffer_depending_on_expert_location_metadata(self): class _DetailAccumulator(_Accumulator): @classmethod def postprocess_dumps( - cls, - dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): # Do not convert to logical since we want all details return [record for dump in dumps for record in dump] @@ -425,14 +449,16 @@ def get_single_pass_gatherer_key(self, debug_name: Optional[str]): return super().get_single_pass_gatherer_key(debug_name) def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_global_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_global_physical_count: torch.Tensor, ): single_pass_global_physical_count = single_pass_global_physical_count.to("cpu") if self._save_dir is None: - single_pass_global_physical_count = single_pass_global_physical_count.tolist() + single_pass_global_physical_count = ( + single_pass_global_physical_count.tolist() + ) self._records.append( dict( @@ -462,11 +488,13 @@ def flush_buffer_depending_on_expert_location_metadata(self): class _StatAccumulator(_Accumulator): @classmethod def postprocess_dumps( - cls, - dumps: List[Any], - expert_location_metadata: "ExpertLocationMetadata", + cls, + dumps: List[Any], + expert_location_metadata: "ExpertLocationMetadata", ): - logical_count = torch.tensor([item["logical_count"] for item in dumps]).sum(dim=0) + logical_count = torch.tensor([item["logical_count"] for item in dumps]).sum( + dim=0 + ) return dict(logical_count=logical_count.tolist()) def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int): @@ -485,10 +513,10 @@ def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int ) def append( - self, - forward_pass_id: int, - gatherer_key: str, - single_pass_global_physical_count: torch.Tensor, + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_global_physical_count: torch.Tensor, ): # Can optimize if overhead here is large self._buffer_global_physical_count += single_pass_global_physical_count @@ -506,14 +534,16 @@ def dump(self): ) def flush_buffer_depending_on_expert_location_metadata(self): - self._logical_count += 
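[_convert_local_to_global_physical_count, whose reformatted body appears above, writes one EP rank's slice into the global physical-expert axis and leaves the rest zero. Self-contained with toy sizes:]

    import torch

    def local_to_global(local_count: torch.Tensor, rank: int,
                        num_local: int, num_physical: int) -> torch.Tensor:
        # local_count: (num_layers, num_local) counts observed by this EP rank,
        # which owns global slots [num_local * rank, num_local * (rank + 1)).
        num_layers, _ = local_count.shape
        out = torch.zeros((num_layers, num_physical), dtype=local_count.dtype)
        out[:, num_local * rank : num_local * (rank + 1)] = local_count
        return out

    print(local_to_global(torch.ones(1, 2, dtype=torch.int), rank=1, num_local=2, num_physical=6))
    # tensor([[0, 0, 1, 1, 0, 0]], dtype=torch.int32)
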
_convert_global_physical_count_to_logical_count(self._buffer_global_physical_count, - expert_location_metadata=self._expert_location_metadata) + self._logical_count += _convert_global_physical_count_to_logical_count( + self._buffer_global_physical_count, + expert_location_metadata=self._expert_location_metadata, + ) self._buffer_global_physical_count[...] = 0 def _convert_global_physical_count_to_logical_count( - global_physical_count: torch.Tensor, - expert_location_metadata: ExpertLocationMetadata, + global_physical_count: torch.Tensor, + expert_location_metadata: ExpertLocationMetadata, ): num_layers = expert_location_metadata.num_layers num_logical_experts = expert_location_metadata.num_logical_experts diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 21b9a016a80..19852bf00f8 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.model_loader import get_model_architecture @@ -46,9 +47,15 @@ def ep_size(self): def __post_init__(self): num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape - num_layers_1, num_logical_experts_0, num_physical_experts_1 = self.logical_to_all_physical_map.shape - num_layers_2, num_logical_experts_1 = self.logical_to_all_physical_map_num_valid.shape - ep_size_0, num_layers_3, num_logical_experts_2 = self.logical_to_rank_dispatch_physical_map.shape + num_layers_1, num_logical_experts_0, num_physical_experts_1 = ( + self.logical_to_all_physical_map.shape + ) + num_layers_2, num_logical_experts_1 = ( + self.logical_to_all_physical_map_num_valid.shape + ) + ep_size_0, num_layers_3, num_logical_experts_2 = ( + self.logical_to_rank_dispatch_physical_map.shape + ) assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3 assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2 assert num_physical_experts_0 == num_physical_experts_1 @@ -65,8 +72,8 @@ def init_trivial(server_args: ServerArgs): num_logical_experts = model_config_for_expert_location.num_logical_experts physical_to_logical_map = ( - torch.arange(0, num_physical_experts).repeat(num_layers, 1) - % num_logical_experts + torch.arange(0, num_physical_experts).repeat(num_layers, 1) + % num_logical_experts ) return ExpertLocationMetadata.init_by_mapping( @@ -123,8 +130,8 @@ def _init_common(server_args: ServerArgs): ) num_physical_experts = ( - model_config_for_expert_location.num_logical_experts - + server_args.ep_num_redundant_experts + model_config_for_expert_location.num_logical_experts + + server_args.ep_num_redundant_experts ) ep_size = server_args.ep_size assert num_physical_experts % ep_size == 0 @@ -139,9 +146,9 @@ def _init_common(server_args: ServerArgs): @staticmethod def _init_raw( - ep_size: int, - physical_to_logical_map: torch.Tensor, - logical_to_all_physical_map: torch.Tensor, + ep_size: int, + physical_to_logical_map: torch.Tensor, + logical_to_all_physical_map: torch.Tensor, ): _, num_physical_experts = physical_to_logical_map.shape @@ -151,7 +158,9 @@ def _init_raw( value=-1, ) - logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1) + logical_to_all_physical_map_num_valid = torch.count_nonzero( + logical_to_all_physical_map != -1, dim=-1 + ) return ExpertLocationMetadata( 
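[The flush above credits each physical replica's count to the logical expert it serves. The body of _convert_global_physical_count_to_logical_count is not fully shown in this hunk, but given the shapes, a plausible sketch is a scatter-add along the expert axis:]

    import torch

    def physical_to_logical_count(global_physical_count: torch.Tensor,
                                  physical_to_logical_map: torch.Tensor,
                                  num_logical_experts: int) -> torch.Tensor:
        # global_physical_count / physical_to_logical_map:
        #   both (num_layers, num_physical_experts); map values are logical ids.
        num_layers, _ = global_physical_count.shape
        logical = torch.zeros((num_layers, num_logical_experts),
                              dtype=global_physical_count.dtype)
        logical.scatter_add_(1, physical_to_logical_map, global_physical_count)
        return logical

    counts = torch.tensor([[3, 1, 2, 4]])   # 1 layer, 4 physical experts
    mapping = torch.tensor([[0, 1, 0, 1]])  # replicas 0,2 serve logical 0; 1,3 serve logical 1
    print(physical_to_logical_count(counts, mapping, num_logical_experts=2))  # tensor([[5, 5]])
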
physical_to_logical_map=physical_to_logical_map, @@ -166,8 +175,12 @@ def _init_raw( # -------------------------------- mutation ------------------------------------ - def update(self, other: "ExpertLocationMetadata", layer_id_start: Optional[int] = None, - layer_id_len: Optional[int] = None): + def update( + self, + other: "ExpertLocationMetadata", + layer_id_start: Optional[int] = None, + layer_id_len: Optional[int] = None, + ): for field in [ "ep_size", ]: @@ -179,10 +192,13 @@ def update(self, other: "ExpertLocationMetadata", layer_id_start: Optional[int] ("logical_to_all_physical_map_num_valid", 0), ("logical_to_rank_dispatch_physical_map", 1), ]: + def _get(obj): ans = getattr(obj, field) if (layer_id_start is not None) or (layer_id_len is not None): - ans = ans.narrow(dim=layer_id_dim, start=layer_id_start, length=layer_id_len) + ans = ans.narrow( + dim=layer_id_dim, start=layer_id_start, length=layer_id_len + ) return ans # Cannot update address to avoid breaking CUDA graph @@ -203,7 +219,7 @@ def local_physical_to_physical(self, rank: int, local_physical_expert_index: int return self.num_local_physical_experts * rank + local_physical_expert_index def logical_to_all_physical( - self, layer_id: int, logical_expert_id: int + self, layer_id: int, logical_expert_id: int ) -> List[int]: return self.logical_to_all_physical_raw( self.logical_to_all_physical_map, layer_id, logical_expert_id @@ -211,7 +227,7 @@ def logical_to_all_physical( @staticmethod def logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id: int, logical_expert_id: int + logical_to_all_physical_map, layer_id: int, logical_expert_id: int ) -> List[int]: return [ physical_expert_id @@ -231,7 +247,7 @@ def debug_str(self): def _compute_logical_to_all_physical_map( - physical_to_logical_map: torch.Tensor, num_logical_experts: int + physical_to_logical_map: torch.Tensor, num_logical_experts: int ): # This is rarely called, so we use for loops for maximum clarity @@ -267,10 +283,10 @@ def _pad_nested_array(arr, pad_value): # This is rarely called, so we use for loops for maximum clarity def compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map: torch.Tensor, - num_gpus: int, - num_physical_experts: int, - seed: int = 42, + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, + num_physical_experts: int, + seed: int = 42, ): r = random.Random(seed) @@ -286,29 +302,40 @@ def compute_logical_to_rank_dispatch_physical_map( for layer_id in range(num_layers): for logical_expert_id in range(num_logical_experts): - candidate_physical_expert_ids = ExpertLocationMetadata.logical_to_all_physical_raw( - logical_to_all_physical_map, layer_id, logical_expert_id + candidate_physical_expert_ids = ( + ExpertLocationMetadata.logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id + ) ) - output_partial = logical_to_rank_dispatch_physical_map[:, layer_id, logical_expert_id] + output_partial = logical_to_rank_dispatch_physical_map[ + :, layer_id, logical_expert_id + ] for gpu_id in range(num_gpus): same_gpu_physical_expert_ids = [ physical_expert_id for physical_expert_id in candidate_physical_expert_ids - if _compute_gpu_id_of_physical_expert(physical_expert_id, num_local_physical_experts) == gpu_id + if _compute_gpu_id_of_physical_expert( + physical_expert_id, num_local_physical_experts + ) + == gpu_id ] if len(same_gpu_physical_expert_ids) > 0: output_partial[gpu_id] = same_gpu_physical_expert_ids[0] num_remain = torch.sum(output_partial == -1).item() 
output_partial[output_partial == -1] = torch.tensor( - _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), dtype=dtype) + _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), + dtype=dtype, + ) assert torch.all(logical_to_rank_dispatch_physical_map != -1) return logical_to_rank_dispatch_physical_map -def _compute_gpu_id_of_physical_expert(physical_expert_id: int, num_local_physical_experts: int) -> int: +def _compute_gpu_id_of_physical_expert( + physical_expert_id: int, num_local_physical_experts: int +) -> int: return physical_expert_id // num_local_physical_experts diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2537f5a2c60..3dd6161ee17 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -53,7 +53,9 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import ( AbortReq, diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py index 30a26e3cc1e..f6c141a11ad 100644 --- a/python/sglang/srt/managers/scheduler_input_blocker.py +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -39,7 +39,9 @@ def handle(self, recv_reqs: Optional[List[Any]]): for recv_req in recv_reqs: output_reqs += self._handle_recv_req(recv_req) - global_arrived_unblock_barrier = self._global_unblock_barrier.poll_global_arrived() + global_arrived_unblock_barrier = ( + self._global_unblock_barrier.poll_global_arrived() + ) if ( self._state == _State.GLOBAL_UNBLOCK_BARRIER and global_arrived_unblock_barrier diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 39041f721c7..7b8af1fbb40 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,6 +45,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -716,18 +717,27 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): for layer_id_end in layer_id_lens: logger.info(f"update_expert_location handling up to {layer_id_end}th layer") - partial_expert_location_metadata = copy.deepcopy(old_expert_location_metadata) - partial_expert_location_metadata.update(obj.expert_location_metadata, layer_id_start=0, - layer_id_len=layer_id_end) + partial_expert_location_metadata = copy.deepcopy( + old_expert_location_metadata + ) + partial_expert_location_metadata.update( + obj.expert_location_metadata, + layer_id_start=0, + layer_id_len=layer_id_end, + ) await self._update_expert_location_raw( expert_location_metadata=partial_expert_location_metadata, ) - async def _update_expert_location_raw(self, expert_location_metadata: ExpertLocationMetadata): + async def _update_expert_location_raw( + self, expert_location_metadata: ExpertLocationMetadata + ): self.expert_location_metadata = None - await 
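[compute_logical_to_rank_dispatch_physical_map above prefers a replica already resident on the querying GPU and only falls back to a seeded random pick (_fair_choices, not shown in this hunk) when none exists. The core decision, extracted:]

    import random

    def dispatch_for_gpu(candidates, gpu_id, num_local_physical_experts, rng):
        # candidates: physical replica ids of one logical expert (-1 padding removed).
        # A replica on the querying GPU keeps dispatch local, so it wins.
        same_gpu = [p for p in candidates
                    if p // num_local_physical_experts == gpu_id]
        if same_gpu:
            return same_gpu[0]
        return rng.choice(candidates)

    rng = random.Random(42)
    print(dispatch_for_gpu([3, 9], gpu_id=1, num_local_physical_experts=2, rng=rng))  # 3 (local)
    print(dispatch_for_gpu([9], gpu_id=0, num_local_physical_experts=2, rng=rng))     # 9 (fallback)
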
self.update_expert_location_communicator(UpdateExpertLocationReqInput( - expert_location_metadata=expert_location_metadata, - )) + await self.update_expert_location_communicator( + UpdateExpertLocationReqInput( + expert_location_metadata=expert_location_metadata, + ) + ) self.expert_location_metadata = expert_location_metadata async def update_weights_from_disk( @@ -1037,8 +1047,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 0b3f9024b57..f1801487505 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -3,12 +3,21 @@ from typing import TYPE_CHECKING, Dict, List, Optional import torch -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder + +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.expert_location import ExpertLocationMetadata -from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput, UpdateExpertLocationReqOutput +from sglang.srt.managers.io_struct import ( + UpdateExpertLocationReqInput, + UpdateExpertLocationReqOutput, +) from sglang.srt.managers.schedule_batch import get_global_expert_location_metadata from sglang.srt.model_executor.model_weight_updater import ModelWeightUpdater -from sglang.srt.model_loader.weight_utils import ModelParamNameInfo, ModelParamNameInfoMoe +from sglang.srt.model_loader.weight_utils import ( + ModelParamNameInfo, + ModelParamNameInfoMoe, +) from sglang.srt.poll_based_barrier import PollBasedBarrier from sglang.srt.utils import get_bool_env_var @@ -42,14 +51,18 @@ def start(self, req: UpdateExpertLocationReqInput): assert self._ongoing_req is None self._ongoing_req = req - interesting_logical_experts_of_layer = _compute_interesting_logical_experts_of_layer( - old_expert_location_metadata=get_global_expert_location_metadata(), - new_expert_location_metadata=req.expert_location_metadata, - ep_rank=self._model_runner.tp_rank, + interesting_logical_experts_of_layer = ( + _compute_interesting_logical_experts_of_layer( + old_expert_location_metadata=get_global_expert_location_metadata(), + new_expert_location_metadata=req.expert_location_metadata, + ep_rank=self._model_runner.tp_rank, + ) ) self._model_weight_updater.start_prepare( - weight_filter=lambda name, info: self._weight_filter(info, interesting_logical_experts_of_layer), + weight_filter=lambda name, info: self._weight_filter( + info, interesting_logical_experts_of_layer + ), ) _log_with_accurate_time("ExpertLocationUpdater.start end") @@ -57,7 +70,9 @@ def event_loop_step(self) -> List[UpdateExpertLocationReqOutput]: outputs = [] if self._model_weight_updater.poll_prepare_end(): - _log_with_accurate_time("ExpertLocationUpdater.event_loop_step observe local_arrive") + _log_with_accurate_time( + "ExpertLocationUpdater.event_loop_step observe local_arrive" + ) self._prepare_end_barrier.local_arrive() if self._prepare_end_barrier.poll_global_arrived(): @@ -71,17 +86,23 @@ def _act(self): 
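[update_expert_location above applies the new mapping in growing layer prefixes ([0, 10), [0, 20), ...) rather than all at once, bounding how much weight traffic each round triggers; a later patch in this series derives the step from the layer count instead of hardcoding 10. The prefix schedule, sketched:]

    import math

    def chunked_layer_ends(num_layers: int, num_chunks: int = 5):
        # Cumulative prefix lengths: each round re-dispatches layers [0, end).
        step = math.ceil(num_layers / num_chunks)
        return list(range(step, num_layers, step)) + [num_layers]

    print(chunked_layer_ends(61))  # [13, 26, 39, 52, 61]
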
get_global_expert_distribution_recorder().flush_buffer_depending_on_expert_location_metadata() - get_global_expert_location_metadata().update(self._ongoing_req.expert_location_metadata) + get_global_expert_location_metadata().update( + self._ongoing_req.expert_location_metadata + ) if self._model_runner.tp_rank == 0 and get_bool_env_var( - "SGLANG_LOG_EXPERT_LOCATION_METADATA" + "SGLANG_LOG_EXPERT_LOCATION_METADATA" ): logger.info( f"Updated expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" ) - _log_with_accurate_time("ExpertLocationUpdater.act execute ModelWeightUpdater.act start") + _log_with_accurate_time( + "ExpertLocationUpdater.act execute ModelWeightUpdater.act start" + ) self._model_weight_updater.act() - _log_with_accurate_time("ExpertLocationUpdater.act execute ModelWeightUpdater.act end") + _log_with_accurate_time( + "ExpertLocationUpdater.act execute ModelWeightUpdater.act end" + ) torch.distributed.barrier() @@ -91,33 +112,50 @@ def _act(self): _log_with_accurate_time("ExpertLocationUpdater.act end") return UpdateExpertLocationReqOutput() - def _weight_filter(self, info: ModelParamNameInfo, - interesting_logical_experts_of_layer: Dict[int, List[int]]): - return ( - isinstance(info, ModelParamNameInfoMoe) - and (info.expert_id in interesting_logical_experts_of_layer.get(info.layer_id, [])) + def _weight_filter( + self, + info: ModelParamNameInfo, + interesting_logical_experts_of_layer: Dict[int, List[int]], + ): + return isinstance(info, ModelParamNameInfoMoe) and ( + info.expert_id + in interesting_logical_experts_of_layer.get(info.layer_id, []) ) def _compute_interesting_logical_experts_of_layer( - old_expert_location_metadata: ExpertLocationMetadata, - new_expert_location_metadata: ExpertLocationMetadata, - ep_rank: int, + old_expert_location_metadata: ExpertLocationMetadata, + new_expert_location_metadata: ExpertLocationMetadata, + ep_rank: int, ) -> Dict[int, List[int]]: num_layers = old_expert_location_metadata.num_layers num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts - def _get_partial_physical_to_logical_map(meta: ExpertLocationMetadata, layer_id: int): - return meta.physical_to_logical_map[layer_id, - num_local_physical_experts * ep_rank: num_local_physical_experts * (ep_rank + 1)] + def _get_partial_physical_to_logical_map( + meta: ExpertLocationMetadata, layer_id: int + ): + return meta.physical_to_logical_map[ + layer_id, + num_local_physical_experts + * ep_rank : num_local_physical_experts + * (ep_rank + 1), + ] interesting_logical_experts_of_layer = {} for layer_id in range(num_layers): - old_partial_map = _get_partial_physical_to_logical_map(old_expert_location_metadata, layer_id) - new_partial_map = _get_partial_physical_to_logical_map(new_expert_location_metadata, layer_id) - interesting_logical_experts_of_layer[layer_id] = new_partial_map[new_partial_map != old_partial_map].tolist() + old_partial_map = _get_partial_physical_to_logical_map( + old_expert_location_metadata, layer_id + ) + new_partial_map = _get_partial_physical_to_logical_map( + new_expert_location_metadata, layer_id + ) + interesting_logical_experts_of_layer[layer_id] = new_partial_map[ + new_partial_map != old_partial_map + ].tolist() return interesting_logical_experts_of_layer def _log_with_accurate_time(message): - logger.info(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}") + logger.info( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}] {message}" + ) diff --git 
a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index dbc83acfcde..83e0e39e383 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from queue import SimpleQueue from threading import Thread -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import torch @@ -24,14 +24,20 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): - def __init__(self, manager_a: TensorOperationManagerBase, manager_b: TensorOperationManagerBase): + def __init__( + self, + manager_a: TensorOperationManagerBase, + manager_b: TensorOperationManagerBase, + ): # For simplicity, only support chaining 2 managers, but can be extended to N self._manager_a = manager_a self._manager_b = manager_b @classmethod def init_pin_memory_and_to_cuda(cls, allocator: "SimpleCachingAllocator"): - return cls(manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager(allocator)) + return cls( + manager_a=AsyncPinMemoryManager(), manager_b=AsyncToCudaManager(allocator) + ) def enqueue(self, named_tensors: NamedTensors): self._manager_a.enqueue(named_tensors) @@ -74,7 +80,9 @@ def _background_thread_entrypoint(self): try: while True: input_data = self._input_queue.get() - output_data = [(name, tensor.pin_memory()) for name, tensor in input_data] + output_data = [ + (name, tensor.pin_memory()) for name, tensor in input_data + ] self._output_queue.put(output_data) except Exception as e: logger.warning(f"AsyncPinMemoryManager background thread error {e}") @@ -101,15 +109,20 @@ def enqueue(self, named_tensors: NamedTensors): finish_event = torch.cuda.Event() finish_event.record() - self._inflight_tasks.append(_AsyncToCudaTask( - finish_event=finish_event, - input_named_tensors=named_tensors, - output_named_tensors=output_named_tensors, - )) + self._inflight_tasks.append( + _AsyncToCudaTask( + finish_event=finish_event, + input_named_tensors=named_tensors, + output_named_tensors=output_named_tensors, + ) + ) def get_outputs(self) -> List[NamedTensors]: outputs = [] - while len(self._inflight_tasks) > 0 and self._inflight_tasks[0].finish_event.query(): + while ( + len(self._inflight_tasks) > 0 + and self._inflight_tasks[0].finish_event.query() + ): task = self._inflight_tasks.pop(0) outputs.append(self._handle_one_output(task)) return outputs @@ -123,7 +136,9 @@ def _auto_create_stream(self): self._alt_stream = torch.cuda.Stream() @staticmethod - def _tensor_to_cuda(input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator"): + def _tensor_to_cuda( + input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator" + ): output_tensor = allocator.allocate(input_tensor.size, input_tensor.dtype) output_tensor.copy_(input_tensor, non_blocking=True) return output_tensor diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0fa690126de..14d5bdccd28 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,10 +20,11 @@ import os import time from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union, Any +from typing import Any, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist + from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from 
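[AsyncPinMemoryManager and AsyncToCudaManager above compose the standard pinned-memory, side-stream, and event recipe for overlapping host-to-device copies with compute. Isolated into one function; requires CUDA, and the caching allocator is dropped for brevity:]

    import torch

    def async_h2d(cpu_tensor: torch.Tensor, stream: torch.cuda.Stream):
        # Pinned host memory makes the copy truly asynchronous, and the
        # recorded event is polled later instead of blocking the caller.
        pinned = cpu_tensor.pin_memory()
        with torch.cuda.stream(stream):
            gpu_tensor = pinned.to("cuda", non_blocking=True)
            done = torch.cuda.Event()
            done.record()
        return gpu_tensor, done

    if torch.cuda.is_available():
        s = torch.cuda.Stream()
        t, done = async_h2d(torch.randn(1024), s)
        while not done.query():  # poll, like AsyncToCudaManager.get_outputs
            pass
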
sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -33,7 +34,10 @@ initialize_model_parallel, set_custom_all_reduce, ) -from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state, get_world_group +from sglang.srt.distributed.parallel_state import ( + get_world_group, + monkey_patch_vllm_parallel_state, +) from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, @@ -44,8 +48,11 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder, \ - set_global_expert_distribution_recorder, ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, + set_global_expert_distribution_recorder, +) from sglang.srt.managers.expert_location import ExpertLocationMetadata from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput from sglang.srt.managers.schedule_batch import ( @@ -82,6 +89,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, + broadcast_pyobj, enable_show_time_cost, get_available_gpu_memory, get_bool_env_var, @@ -92,7 +100,7 @@ monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, - set_cuda_arch, broadcast_pyobj, + set_cuda_arch, ) logger = logging.getLogger(__name__) @@ -189,7 +197,7 @@ def __init__( ) # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024 ** 3)) + set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -211,8 +219,11 @@ def __init__( # If it is a draft model tp_group can be different. self.initialize(min_per_gpu_memory) - self._expert_location_updater = ExpertLocationUpdater( - self) if server_args.expert_location_updater_mode is not None else None + self._expert_location_updater = ( + ExpertLocationUpdater(self) + if server_args.expert_location_updater_mode is not None + else None + ) def initialize(self, min_per_gpu_memory: float): server_args = self.server_args @@ -220,11 +231,13 @@ def initialize(self, min_per_gpu_memory: float): enable=self.server_args.enable_memory_saver ) - set_global_expert_distribution_recorder(ExpertDistributionRecorder.init_new( - server_args, - get_global_expert_location_metadata(), - rank=self.tp_rank, - )) + set_global_expert_distribution_recorder( + ExpertDistributionRecorder.init_new( + server_args, + get_global_expert_location_metadata(), + rank=self.tp_rank, + ) + ) # Load the model self.sampler = Sampler() @@ -536,7 +549,10 @@ def filter_weight_iter(iter: Iterable[Tuple[str, torch.Tensor]]): yield from iter else: for name, weight in iter: - if self.model.get_param_name_info(name).category in param_categories: + if ( + self.model.get_param_name_info(name).category + in param_categories + ): yield name, weight def model_load_weights(model, iter): @@ -955,7 +971,7 @@ def init_double_sparsity_channel_config(self, selected_channel): key = "model.layers." 
+ str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ - :, : self.server_args.ds_heavy_channel_num + :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() @@ -1039,7 +1055,9 @@ def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False ) -> LogitsProcessorOutput: self.forward_pass_id += 1 - with get_global_expert_distribution_recorder().with_forward_pass(self.forward_pass_id): + with get_global_expert_distribution_recorder().with_forward_pass( + self.forward_pass_id + ): return self._forward_raw(forward_batch, skip_attn_backend_init) def _forward_raw( diff --git a/python/sglang/srt/model_executor/model_weight_updater.py b/python/sglang/srt/model_executor/model_weight_updater.py index 798228b4053..832128239c0 100644 --- a/python/sglang/srt/model_executor/model_weight_updater.py +++ b/python/sglang/srt/model_executor/model_weight_updater.py @@ -1,12 +1,17 @@ import logging from abc import ABC from dataclasses import dataclass -from typing import Tuple, List, Iterable +from typing import Iterable, List, Tuple import torch + from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.model_executor.memory_transfer import AsyncToCudaManager, CombinedManager, SimpleCachingAllocator +from sglang.srt.model_executor.memory_transfer import ( + AsyncToCudaManager, + CombinedManager, + SimpleCachingAllocator, +) from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -15,24 +20,29 @@ class ModelWeightUpdater: def __init__( - self, - init_pin_memory: bool, - load_format: str, - model_config: ModelConfig, - model, - device, + self, + init_pin_memory: bool, + load_format: str, + model_config: ModelConfig, + model, + device, ): self._model_config = model_config self._model = model self._device = device - self._all_weights_and_info = _get_all_weights_and_info(load_format=load_format, model_config=model_config, - model=model, - pin_memory=init_pin_memory) + self._all_weights_and_info = _get_all_weights_and_info( + load_format=load_format, + model_config=model_config, + model=model, + pin_memory=init_pin_memory, + ) self._transfer_allocator = SimpleCachingAllocator(device="cuda") - self._memory_transfer_manager = AsyncToCudaManager( - self._transfer_allocator) if init_pin_memory else CombinedManager.init_pin_memory_and_to_cuda( - self._transfer_allocator) + self._memory_transfer_manager = ( + AsyncToCudaManager(self._transfer_allocator) + if init_pin_memory + else CombinedManager.init_pin_memory_and_to_cuda(self._transfer_allocator) + ) self._state: _State = _StateIdle() @@ -40,8 +50,11 @@ def start_prepare(self, weight_filter): assert isinstance(self._state, _StateIdle) self._transfer_allocator.mark_all_unused() - interesting_weights = [(name, weight) for name, weight, info in self._all_weights_and_info if - weight_filter(name, info)] + interesting_weights = [ + (name, weight) + for name, weight, info in self._all_weights_and_info + if weight_filter(name, info) + ] self._memory_transfer_manager.enqueue(interesting_weights) self._state = _StateAwaitMemoryTransfer() @@ -67,7 +80,9 @@ def act(self): # TODO further extract such common operations during weight loading with set_default_torch_dtype(self._model_config.dtype): - DefaultModelLoader.load_weights_and_postprocess(self._model, named_tensors, target_device) + DefaultModelLoader.load_weights_and_postprocess( 
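[The with_forward_pass usage above, and the _current_layer_idx.value assertions earlier, rely on the Withable helper from sglang.srt.utils, whose implementation is not shown in this series. Its observable behavior, a value that exists only inside a with scope, can be re-created in a few lines; this is an inferred sketch, not the real class:]

    import contextlib

    class Withable:
        def __init__(self):
            self._value = None

        @property
        def value(self):
            return self._value

        @contextlib.contextmanager
        def with_value(self, value):
            # Set the slot for the duration of the block, then clear it, so
            # stale state cannot leak across forward passes.
            assert self._value is None
            self._value = value
            try:
                yield
            finally:
                self._value = None

    current_layer = Withable()
    with current_layer.with_value(7):
        print(current_layer.value)  # 7
    print(current_layer.value)      # None
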
+ self._model, named_tensors, target_device + ) self._state = _StateIdle() @@ -91,23 +106,30 @@ class _StatePrepared(_State): named_tensors: List[Tuple[str, torch.Tensor]] -def _get_all_weights_and_info(load_format: str, model_config: ModelConfig, model, pin_memory: bool): +def _get_all_weights_and_info( + load_format: str, model_config: ModelConfig, model, pin_memory: bool +): load_config = LoadConfig(load_format=load_format) loader = get_model_loader(load_config) assert isinstance(loader, DefaultModelLoader) with set_default_torch_dtype(model_config.dtype): - all_weights = list(loader._get_weights_iterator(DefaultModelLoader.Source.init_new(model_config, model))) + all_weights = list( + loader._get_weights_iterator( + DefaultModelLoader.Source.init_new(model_config, model) + ) + ) if pin_memory: all_weights = _named_tensors_pin_memory(all_weights) all_weights_and_info = [ - (name, weight, model.get_param_name_info(name)) - for name, weight in all_weights + (name, weight, model.get_param_name_info(name)) for name, weight in all_weights ] return all_weights_and_info -def _named_tensors_pin_memory(named_tensors: Iterable[Tuple[str, torch.Tensor]]) -> List[Tuple[str, torch.Tensor]]: +def _named_tensors_pin_memory( + named_tensors: Iterable[Tuple[str, torch.Tensor]] +) -> List[Tuple[str, torch.Tensor]]: return [(name, tensor.pin_memory()) for name, tensor in named_tensors] diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index e86001b41bd..5d44bcbd981 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -29,12 +29,13 @@ import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator +from tqdm.auto import tqdm + from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.utils import print_warning_once -from tqdm.auto import tqdm logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fce64a9c4f5..47fd2a58130 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -25,6 +25,10 @@ import torch import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -69,7 +73,9 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.expert_location import ( ExpertLocationMetadata, ModelConfigForExpertLocation, @@ -79,12 +85,13 @@ global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.model_loader.weight_utils import default_weight_loader, ModelParamNameInfo, ModelParamNameInfoMoe, \ - ModelParamNameInfoOthers +from sglang.srt.model_loader.weight_utils import ( + ModelParamNameInfo, + ModelParamNameInfoMoe, + ModelParamNameInfoOthers, + default_weight_loader, +) from sglang.srt.utils import DeepEPMode, 
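[ModelWeightUpdater above cycles Idle -> AwaitMemoryTransfer -> Prepared -> Idle, with start_prepare, poll_prepare_end, and act as the only entry points. A toy version with the transfer made synchronous shows the intended call order; state names mirror the hunk, the logic is deliberately simplified:]

    class ToyWeightUpdater:
        def __init__(self):
            self.state = "idle"
            self._pending = None

        def start_prepare(self, named_tensors):
            assert self.state == "idle"
            self._pending = named_tensors  # stands in for the async transfer
            self.state = "await_memory_transfer"

        def poll_prepare_end(self) -> bool:
            if self.state == "await_memory_transfer" and self._pending is not None:
                self.state = "prepared"    # the "transfer" finishes instantly here
            return self.state == "prepared"

        def act(self):
            assert self.state == "prepared"
            ready, self._pending = self._pending, None
            self.state = "idle"
            return ready

    u = ToyWeightUpdater()
    u.start_prepare([("w", 1.0)])
    assert u.poll_prepare_end()
    print(u.act())  # [('w', 1.0)]
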
add_prefix, is_cuda, is_hip -from torch import nn -from tqdm import tqdm -from transformers import PretrainedConfig _is_hip = is_hip() _is_cuda = is_cuda() @@ -216,8 +223,8 @@ def __init__( self.experts = MoEImpl( num_experts=config.n_routed_experts - + self.n_share_experts_fusion - + global_server_args_dict["ep_num_redundant_experts"], + + self.n_share_experts_fusion + + global_server_args_dict["ep_num_redundant_experts"], top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -282,7 +289,7 @@ def __init__( router_topk=self.top_k, permute_fusion=True, num_experts=config.n_routed_experts - + global_server_args_dict["ep_num_redundant_experts"], + + global_server_args_dict["ep_num_redundant_experts"], num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, @@ -440,7 +447,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -553,12 +560,12 @@ def forward( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( -1, self.num_local_heads * 256 ) @@ -570,8 +577,8 @@ def forward( ) attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output @@ -612,7 +619,7 @@ def __init__( self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim ** -0.5 + self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -799,16 +806,16 @@ def forward_normal( kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim:] - k_pe = latent_cache[:, :, self.kv_lora_rank:] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., self.qk_nope_head_dim :] = k_pe latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank:] = k_pe + latent_cache[:, :, self.kv_lora_rank :] = k_pe # Save latent cache forward_batch.token_to_kv_pool.set_kv_buffer( @@ -861,11 +868,11 @@ def forward_absorb( v_input = 
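[The slice reformatting above all manipulates one layout: each MLA query/key head is a content ("nope") part concatenated with a rotary ("pe") part, and the rotary half is written back in place after applying RoPE. With toy dimensions:]

    import torch

    qk_nope_head_dim, qk_rope_head_dim = 4, 2
    q = torch.zeros(1, 3, qk_nope_head_dim + qk_rope_head_dim)  # (tokens, heads, head_dim)
    q_pe = torch.ones(1, 3, qk_rope_head_dim)                   # rotary-embedded half
    q[..., qk_nope_head_dim:] = q_pe                            # write-back, as in the hunks
    print(q[0, 0])  # tensor([0., 0., 0., 0., 1., 1.])
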
self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) k_input = latent_cache.unsqueeze(1) k_input[..., : self.kv_lora_rank] = v_input - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -941,15 +948,15 @@ def forward_absorb_fused_mla_rope( k_input[..., : self.kv_lora_rank] = v_input if not enable_rope_fusion: - k_pe = k_input[..., self.kv_lora_rank:] + k_pe = k_input[..., self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q_input[..., self.kv_lora_rank:] = q_pe - k_input[..., self.kv_lora_rank:] = k_pe + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe k_pe_output = None else: - k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank:]) + k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :]) - q_input[..., self.kv_lora_rank:] = q_pe + q_input[..., self.kv_lora_rank :] = q_pe # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) # Use Fused ROPE with use_rope=OFF. @@ -1006,7 +1013,7 @@ def forward_absorb_fused_mla_rope( ) if enable_rope_fusion: - k_input[..., self.kv_lora_rank:] = k_pe_output + k_input[..., self.kv_lora_rank :] = k_pe_output forward_batch.token_to_kv_pool.set_kv_buffer( self.attn_mqa, forward_batch.out_cache_loc, k_input, None ) @@ -1165,7 +1172,7 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): execution_mode = ( _DecoderLayerExecutionMode.MLP_INPUT_ONE if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) else _DecoderLayerExecutionMode.MLP_INPUT_ALL ) return _DecoderLayerInfo(is_sparse=is_sparse, execution_mode=execution_mode) @@ -1571,7 +1578,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config.moe_layer_freq, ), desc=f"Cloning {self.n_share_experts_fusion} " - "replicas of the shared expert into MoE", + "replicas of the shared expert into MoE", ): for num_repeat in range(self.n_share_experts_fusion): for suffix in suffix_list: @@ -1602,11 +1609,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts - + ( - self.n_share_experts_fusion - if self.n_share_experts_fusion is not None - else 0 - ), + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 00045645f38..b9a11001c0f 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -44,7 +44,9 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.expert_location import ( ExpertLocationMetadata, ModelConfigForExpertLocation, diff --git 
a/python/sglang/srt/poll_based_barrier.py b/python/sglang/srt/poll_based_barrier.py index 2f30fffddc8..db1d22763c8 100644 --- a/python/sglang/srt/poll_based_barrier.py +++ b/python/sglang/srt/poll_based_barrier.py @@ -23,5 +23,9 @@ def _compute_global_arrived(self) -> bool: local_arrived = self._noop or self._local_arrived global_arrived = torch.tensor(local_arrived) # Can optimize if bottleneck - torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN, group=get_world_group().cpu_group) + torch.distributed.all_reduce( + global_arrived, + torch.distributed.ReduceOp.MIN, + group=get_world_group().cpu_group, + ) return global_arrived.item() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 68a8e005a4e..b9c1939a473 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -165,7 +165,9 @@ class ServerArgs: ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "random"]] = None init_expert_location: Optional[str] = None - expert_location_updater_mode: Optional[Literal["pin_memory", "pageable_memory"]] = None + expert_location_updater_mode: Optional[Literal["pin_memory", "pageable_memory"]] = ( + None + ) enable_eplb: bool = False eplb_storage_dir: str = "/tmp/eplb_storage" eplb_rebalance_period: Optional[int] = None @@ -341,15 +343,19 @@ def __post_init__(self): if self.expert_location_updater_mode is None: self.expert_location_updater_mode = "pin_memory" logger.info( - f"EPLB is enabled. The enable_expert_distribution_recorder and expert_location_updater_mode are automatically set.") + f"EPLB is enabled. The enable_expert_distribution_recorder and expert_location_updater_mode are automatically set." + ) if self.expert_location_updater_mode is not None: self.disable_overlap_schedule = True logger.info( - f"ExpertLocationUpdater is enabled. The disable_overlap_schedule is set.") + f"ExpertLocationUpdater is enabled. The disable_overlap_schedule is set." + ) if self.enable_eplb or (self.init_expert_location is not None): if self.ep_dispatch_algorithm is None: self.ep_dispatch_algorithm = "static" - logger.info(f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured.") + logger.info( + f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured." + ) if self.ep_num_redundant_experts > 0: assert ( @@ -449,8 +455,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " - "tokenizer if available, and 'slow' will " - "always use the slow tokenizer.", + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", @@ -474,21 +480,21 @@ def add_cli_args(parser: argparse.ArgumentParser): "remote", ], help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. 
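[The all_reduce reformatted above in poll_based_barrier.py is the whole mechanism behind PollBasedBarrier: every rank contributes 0 or 1 and a MIN reduction yields 1 only once all ranks have arrived, so each rank can poll cheaply instead of blocking. Extracted, assuming an already-initialized process group (e.g. gloo):]

    import torch
    import torch.distributed as dist

    def poll_global_arrived(local_arrived: bool, group=None) -> bool:
        # MIN over {0, 1} across ranks: 1 iff every rank has set its flag.
        flag = torch.tensor(int(local_arrived))
        dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=group)
        return bool(flag.item())
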
' - '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization." - '"layered" loads weights layer by layer so that one can quantize a ' - "layer before loading another to make the peak memory envelope " - "smaller.", + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -501,13 +507,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - "BF16 precision for BF16 models.\n" - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.', + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + "BF16 precision for BF16 models.\n" + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", @@ -542,9 +548,9 @@ def add_cli_args(parser: argparse.ArgumentParser): type=nullable_str, default=None, help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--context-length", @@ -586,8 +592,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", ) # Memory and scheduling parser.add_argument( @@ -607,7 +613,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. 
" - "This option is typically used for development and debugging purposes.", + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -1077,7 +1083,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels.", ) parser.add_argument( "--triton-attention-num-kv-splits", @@ -1090,8 +1096,8 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " - "This can potentially increase throughput but may also increase time-to-first-token latency. " - "The default value is 1, meaning only run one decoding step at a time.", + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", @@ -1197,7 +1203,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=0, help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " - "we use tp_size by default.", + "we use tp_size by default.", ) parser.add_argument( "--disable-shared-experts-fusion", @@ -1216,7 +1222,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, required=False, help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", ) # Debug tensor dumps diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ca2f4632525..64f347eb0e5 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -196,7 +196,7 @@ def auto_partition(files, rank, size): if args.auto_partition_size: files = auto_partition(files, args.auto_partition_id, args.auto_partition_size) else: - files = files[args.range_begin: args.range_end] + files = files[args.range_begin : args.range_end] print("The running tests are ", [f.name for f in files]) diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index 0a76bb37571..ce7c6ae65c8 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -8,8 +8,9 @@ from typing import List import numpy as np -import sglang as sgl import torch + +import sglang as sgl from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -42,23 +43,23 @@ def _test_eplb_many_rebalances_core(self, enable_eplb: bool = True): contents_raw = [ dict( prompt="1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", - expect_output='6, 1', + expect_output="6, 1", ), dict( prompt="2*1=2, 2*2=4, 2*3=6, 2*4=", - expect_output='8, 2', + expect_output="8, 2", ), dict( prompt="10*1=10, 10*2=20, 10*3=30, 10*4=40, 10*5=50, 10*6=", - expect_output='60, ', + expect_output="60, ", ), dict( prompt="2/2=1, 4/2=2, 6/2=3, 8/2=", - expect_output='4, 1', + expect_output="4, 1", ), dict( prompt="One plus one is two, one plus two is three, one plus three is", - expect_output=' four, 
one plus', + expect_output=" four, one plus", ), ] @@ -73,12 +74,18 @@ async def _task_generate(): random.shuffle(contents_duplicated) tasks = [] - async for content in _yield_with_poisson_process(contents_duplicated, action_rate=request_rate): + async for content in _yield_with_poisson_process( + contents_duplicated, action_rate=request_rate + ): print(f"[{time.time()}] Action: start async_generate") - tasks.append(asyncio.create_task(engine.async_generate( - prompt=content["prompt"], - sampling_params=dict(temperature=0, max_new_tokens=4), - ))) + tasks.append( + asyncio.create_task( + engine.async_generate( + prompt=content["prompt"], + sampling_params=dict(temperature=0, max_new_tokens=4), + ) + ) + ) actual_outputs = await asyncio.gather(*tasks) actual_output_texts = [x["text"] for x in actual_outputs] @@ -128,12 +135,18 @@ async def _task_rebalance(): del engine def test_eplb_start_rebalance_restart_mode_pin_memory(self): - self._test_eplb_start_rebalance_restart_core(expert_location_updater_mode="pin_memory") + self._test_eplb_start_rebalance_restart_core( + expert_location_updater_mode="pin_memory" + ) def test_eplb_start_rebalance_restart_mode_pageable_memory(self): - self._test_eplb_start_rebalance_restart_core(expert_location_updater_mode="pageable_memory") + self._test_eplb_start_rebalance_restart_core( + expert_location_updater_mode="pageable_memory" + ) - def _test_eplb_start_rebalance_restart_core(self, expert_location_updater_mode: str): + def _test_eplb_start_rebalance_restart_core( + self, expert_location_updater_mode: str + ): print("Action: test_eplb_start_rebalance_restart") with tempfile.TemporaryDirectory() as tmpdir: engine_kwargs = dict( @@ -255,11 +268,11 @@ def test_nontrivial_location(self): offset = 3 physical_to_logical_map = ( - offset - + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( - _NUM_HIDDEN_LAYERS, 1 - ) - ) % _NUM_ROUTED_EXPERTS + offset + + torch.arange(0, _NUM_ROUTED_EXPERTS + ep_num_redundant_experts).repeat( + _NUM_HIDDEN_LAYERS, 1 + ) + ) % _NUM_ROUTED_EXPERTS init_expert_location = dict( physical_to_logical_map=physical_to_logical_map.tolist() ) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index a1a4daf2cc4..9deb084264d 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -3,6 +3,7 @@ import requests import torch + from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, diff --git a/test/srt/test_expert_location.py b/test/srt/test_expert_location.py index 6ed81dfae83..8a4a0db304a 100644 --- a/test/srt/test_expert_location.py +++ b/test/srt/test_expert_location.py @@ -1,7 +1,10 @@ import unittest import torch -from sglang.srt.managers.expert_location import compute_logical_to_rank_dispatch_physical_map + +from sglang.srt.managers.expert_location import ( + compute_logical_to_rank_dispatch_physical_map, +) from sglang.test.test_utils import CustomTestCase @@ -21,30 +24,87 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): ), # Identity map + consider redundant experts ( - [[[0, 8], [1, 9], [2, 10], [3, 11], [4, -1], [5, -1], [6, -1], [7, -1]]], - [[[0, 1, 2, 11, 4, 5, 6, 7]], [[8, 9, 2, 3, 4, 5, 6, 7]], - [[8, 1, 10, 3, 4, 5, 6, 7]], [[0, 9, 10, 11, 4, 5, 6, 7]]], + [ + [ + [0, 8], + [1, 9], + [2, 10], + [3, 11], + [4, -1], + [5, -1], + [6, -1], + [7, -1], + ] + ], + [ + [[0, 1, 2, 11, 4, 5, 6, 7]], + [[8, 9, 2, 3, 4, 5, 6, 7]], + [[8, 1, 10, 3, 4, 5, 6, 
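[_yield_with_poisson_process, used by the test above, is not shown in this series. A plausible stand-in is sketched here; the name and signature are copied from the call site, the body is an assumption. It draws exponential inter-arrival gaps, which is what Poisson arrivals at a fixed rate mean:]

    import asyncio
    import random

    async def _yield_with_poisson_process(items, action_rate: float, seed: int = 0):
        # action_rate is in items per second, so the gaps between consecutive
        # items are Exp(action_rate) distributed.
        rng = random.Random(seed)
        for item in items:
            await asyncio.sleep(rng.expovariate(action_rate))
            yield item

    async def main():
        async for x in _yield_with_poisson_process(["a", "b", "c"], action_rate=100.0):
            print(x)

    asyncio.run(main())
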
7]], + [[0, 9, 10, 11, 4, 5, 6, 7]], + ], ), # One logical expert is put on ALL gpus ( - [[[3, 9, 6, 0], [1, -1, -1, -1], [2, -1, -1, -1], [4, -1, -1, -1], [5, -1, -1, -1], [7, -1, -1, -1], - [8, -1, -1, -1], [10, -1, -1, -1]]], - [[[0, 1, 2, 4, 5, 7, 8, 10]], [[3, 1, 2, 4, 5, 7, 8, 10]], [[6, 1, 2, 4, 5, 7, 8, 10]], - [[9, 1, 2, 4, 5, 7, 8, 10]]], + [ + [ + [3, 9, 6, 0], + [1, -1, -1, -1], + [2, -1, -1, -1], + [4, -1, -1, -1], + [5, -1, -1, -1], + [7, -1, -1, -1], + [8, -1, -1, -1], + [10, -1, -1, -1], + ] + ], + [ + [[0, 1, 2, 4, 5, 7, 8, 10]], + [[3, 1, 2, 4, 5, 7, 8, 10]], + [[6, 1, 2, 4, 5, 7, 8, 10]], + [[9, 1, 2, 4, 5, 7, 8, 10]], + ], ), # One logical expert is put multiple times on ONE gpu ( - [[[2, 0, 1], [3, -1, -1], [4, -1, -1], [5, -1, -1], [6, -1, -1], [7, -1, -1], [8, -1, -1], - [9, -1, -1]]], - [[[2, 3, 4, 5, 6, 7, 8, 9]], [[0, 3, 4, 5, 6, 7, 8, 9]], [[2, 3, 4, 5, 6, 7, 8, 9]], - [[1, 3, 4, 5, 6, 7, 8, 9]]], + [ + [ + [2, 0, 1], + [3, -1, -1], + [4, -1, -1], + [5, -1, -1], + [6, -1, -1], + [7, -1, -1], + [8, -1, -1], + [9, -1, -1], + ] + ], + [ + [[2, 3, 4, 5, 6, 7, 8, 9]], + [[0, 3, 4, 5, 6, 7, 8, 9]], + [[2, 3, 4, 5, 6, 7, 8, 9]], + [[1, 3, 4, 5, 6, 7, 8, 9]], + ], ), # Random ( - [[[4, 11, -1], [5, 9, 0], [6, -1, -1], [8, -1, -1], [1, -1, -1], [10, -1, -1], [2, 3, -1], - [7, -1, -1]]], - [[[11, 0, 6, 8, 1, 10, 2, 7]], [[4, 5, 6, 8, 1, 10, 3, 7]], [[4, 5, 6, 8, 1, 10, 2, 7]], - [[11, 9, 6, 8, 1, 10, 3, 7]]], + [ + [ + [4, 11, -1], + [5, 9, 0], + [6, -1, -1], + [8, -1, -1], + [1, -1, -1], + [10, -1, -1], + [2, 3, -1], + [7, -1, -1], + ] + ], + [ + [[11, 0, 6, 8, 1, 10, 2, 7]], + [[4, 5, 6, 8, 1, 10, 3, 7]], + [[4, 5, 6, 8, 1, 10, 2, 7]], + [[11, 9, 6, 8, 1, 10, 3, 7]], + ], ), ] @@ -59,7 +119,9 @@ def test_compute_logical_to_rank_dispatch_physical_map(self): actual_outputs.append(actual_output) print(f"{actual_output=} {expect_output=}") - for (logical_to_all_physical_map, expect_output), actual_output in zip(cases, actual_outputs): + for (logical_to_all_physical_map, expect_output), actual_output in zip( + cases, actual_outputs + ): self.assertEqual(actual_output, expect_output) From efbf773784872c267054eb2ea2cf9e4dec1cc026 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 20:40:00 +0800 Subject: [PATCH 1069/1089] more --- .../srt/model_executor/memory_transfer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 83e0e39e383..2979583c4db 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -25,9 +25,9 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): def __init__( - self, - manager_a: TensorOperationManagerBase, - manager_b: TensorOperationManagerBase, + self, + manager_a: TensorOperationManagerBase, + manager_b: TensorOperationManagerBase, ): # For simplicity, only support chaining 2 managers, but can be extended to N self._manager_a = manager_a @@ -120,8 +120,8 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: outputs = [] while ( - len(self._inflight_tasks) > 0 - and self._inflight_tasks[0].finish_event.query() + len(self._inflight_tasks) > 0 + and self._inflight_tasks[0].finish_event.query() ): task = self._inflight_tasks.pop(0) outputs.append(self._handle_one_output(task)) @@ -137,9 +137,9 @@ def _auto_create_stream(self): @staticmethod def 
_tensor_to_cuda( - input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator" + input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator" ): - output_tensor = allocator.allocate(input_tensor.size, input_tensor.dtype) + output_tensor = allocator.allocate(input_tensor.size(), input_tensor.dtype) output_tensor.copy_(input_tensor, non_blocking=True) return output_tensor @@ -160,11 +160,13 @@ def __init__(self, device): self._used_pool: List[torch.Tensor] = [] def allocate(self, size, dtype) -> torch.Tensor: + size = tuple(size) + unused_pool_entry = self._unused_pool[(size, dtype)] if len(unused_pool_entry) > 0: output = unused_pool_entry.pop() else: - output = torch.empty(size, dtype=dtype, device=self._device) + output = torch.empty(*size, dtype=dtype, device=self._device) self._used_pool.append(output) From c702f6c805a5c38f6d8e5a2405f5890e438fbba6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 20:55:46 +0800 Subject: [PATCH 1070/1089] more --- python/sglang/srt/model_executor/memory_transfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 2979583c4db..3ef475ab865 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -174,6 +174,6 @@ def allocate(self, size, dtype) -> torch.Tensor: def mark_all_unused(self): for tensor in self._used_pool: - self._unused_pool[(tensor.size, tensor.dtype)].append(tensor) + self._unused_pool[(tuple(tensor.size()), tensor.dtype)].append(tensor) self._used_pool.clear() From dc76f5c60f925e109898918631740548baf7f3ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:09:00 +0800 Subject: [PATCH 1071/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7b8af1fbb40..38356fb3c89 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -17,6 +17,7 @@ import copy import dataclasses import logging +import math import os import pickle import signal @@ -45,7 +46,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -713,7 +713,8 @@ async def update_expert_location(self, obj: UpdateExpertLocationReqInput): num_layers = old_expert_location_metadata.num_layers # pretty arbitrary choice; can optimize if bottleneck - layer_id_lens = list(range(10, num_layers, 10)) + [num_layers] + step = math.ceil(num_layers / 5) + layer_id_lens = list(range(step, num_layers, step)) + [num_layers] for layer_id_end in layer_id_lens: logger.info(f"update_expert_location handling up to {layer_id_end}th layer") @@ -1047,8 +1048,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset : - ] + state.last_output_offset: + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] From a1c895a4a1ae1aa2eda6bdeb7ef80303291b22c2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:33:37 +0800 Subject: [PATCH 1072/1089] more --- 
python/sglang/srt/entrypoints/http_server.py | 4 ++-- python/sglang/srt/managers/io_struct.py | 5 +++++ python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9ddf72cad99..37e396d583e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -371,8 +371,8 @@ async def dump_expert_distribution_record_async(): @app.post("/eplb_rebalance") -async def eplb_rebalance(): - await _global_state.tokenizer_manager.eplb_rebalance() +async def eplb_rebalance(obj: EplbRebalanceReqInput): + await _global_state.tokenizer_manager.eplb_rebalance(obj) return ORJSONResponse({}, status_code=200) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8c67b9f189c..2c3dfe13be1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -677,6 +677,11 @@ class FlushCacheReqOutput: success: bool +@dataclass +class EplbRebalanceReqInput: + debug_use_random_stat: bool = False + + @dataclass class UpdateExpertLocationReqInput: expert_location_metadata: "ExpertLocationMetadata" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 38356fb3c89..6e3d9fb2be1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -96,7 +96,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, + UpdateWeightsFromTensorReqOutput, EplbRebalanceReqInput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -695,7 +695,7 @@ async def dump_expert_distribution_record(self): expert_location_metadata=self.expert_location_metadata, ) - async def eplb_rebalance(self): + async def eplb_rebalance(self, obj: EplbRebalanceReqInput): self.auto_create_handle_loop() await self.eplb_manager.rebalance() From ba4d44949e5121320f0067a473068847e6ea52cb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:34:04 +0800 Subject: [PATCH 1073/1089] more --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6e3d9fb2be1..ff0854ef21a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -697,7 +697,7 @@ async def dump_expert_distribution_record(self): async def eplb_rebalance(self, obj: EplbRebalanceReqInput): self.auto_create_handle_loop() - await self.eplb_manager.rebalance() + await self.eplb_manager.rebalance(obj) async def eplb_save_expert_distribution(self): self.auto_create_handle_loop() From 5b0f321045599a5e419a89c0da68ab7631217542 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:34:21 +0800 Subject: [PATCH 1074/1089] more --- python/sglang/srt/entrypoints/engine.py | 4 ++-- python/sglang/srt/managers/eplb_manager.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b98c37fa9e3..ab091fc59cf 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -61,7 +61,7 @@ UpdateExpertLocationReqInput, 
UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqInput, EplbRebalanceReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -368,7 +368,7 @@ def update_weights_from_tensor( def eplb_rebalance(self): loop = asyncio.get_event_loop() - return loop.run_until_complete(self.tokenizer_manager.eplb_rebalance()) + return loop.run_until_complete(self.tokenizer_manager.eplb_rebalance(EplbRebalanceReqInput())) def eplb_save_expert_distribution(self): loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 36f842e0b67..b54ad8fea31 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -10,7 +10,7 @@ ExpertLocationMetadata, ModelConfigForExpertLocation, ) -from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput +from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput, EplbRebalanceReqInput from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -40,9 +40,9 @@ async def handle_loop(self): f"EPLBManager: Sleep {sleep_time} seconds before next automatic rebalancing" ) await asyncio.sleep(sleep_time) - await self.rebalance() + await self.rebalance(EplbRebalanceReqInput()) - async def rebalance(self): + async def rebalance(self, obj: EplbRebalanceReqInput): await self.save_expert_distribution() expert_location_metadata = self.compute_expert_location_metadata() await self._tokenizer_manager.update_expert_location( From 552b7f5ae8b70e4ff280cc3d16f836df9708dd6f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:34:59 +0800 Subject: [PATCH 1075/1089] more --- python/sglang/srt/managers/eplb_manager.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index b54ad8fea31..65e18039ef0 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -25,7 +25,7 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) - / "expert_distribution_storage" + / "expert_distribution_storage" ) def bind(self, tokenizer_manager: "TokenizerManager"): @@ -44,7 +44,8 @@ async def handle_loop(self): async def rebalance(self, obj: EplbRebalanceReqInput): await self.save_expert_distribution() - expert_location_metadata = self.compute_expert_location_metadata() + expert_location_metadata = self.compute_expert_location_metadata( + debug_use_random_stat=obj.debug_use_random_stat) await self._tokenizer_manager.update_expert_location( UpdateExpertLocationReqInput( expert_location_metadata=expert_location_metadata @@ -54,7 +55,7 @@ async def rebalance(self, obj: EplbRebalanceReqInput): async def save_expert_distribution(self): await self._expert_distribution_storage.save_current() - def compute_expert_location_metadata(self): + def compute_expert_location_metadata(self, debug_use_random_stat: bool = False): snapshot = self._expert_distribution_storage.get_last_snapshot() if snapshot is None: return ExpertLocationMetadata.init_trivial(self._server_args) From 5fa7fa8dde60f790eb2d81901d4998d1257ac209 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:35:40 +0800 Subject: [PATCH 
1076/1089] more --- python/sglang/srt/managers/eplb_manager.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 65e18039ef0..d3baa4c4a42 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -59,4 +59,9 @@ def compute_expert_location_metadata(self, debug_use_random_stat: bool = False): snapshot = self._expert_distribution_storage.get_last_snapshot() if snapshot is None: return ExpertLocationMetadata.init_trivial(self._server_args) + + if debug_use_random_stat: + logger.warning("EPLBManager.compute_expert_location_metadata use random stat for debugging.") + snapshot = {"logical_count": TODO} + return ExpertLocationMetadata.init_by_eplb(self._server_args, **snapshot) From dd910a4027168ceb823829c55052d37f0f18269c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:36:27 +0800 Subject: [PATCH 1077/1089] more --- python/sglang/srt/managers/eplb_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index d3baa4c4a42..9d865585e61 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING +import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -62,6 +63,7 @@ def compute_expert_location_metadata(self, debug_use_random_stat: bool = False): if debug_use_random_stat: logger.warning("EPLBManager.compute_expert_location_metadata use random stat for debugging.") - snapshot = {"logical_count": TODO} + original_logical_count = torch.tensor(snapshot["logical_count"]) + snapshot = {"logical_count": torch.randint_like(original_logical_count, high=100000)} return ExpertLocationMetadata.init_by_eplb(self._server_args, **snapshot) From 6b19c68b0156f75bce7fff0d3ebc07fb4af84103 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:37:34 +0800 Subject: [PATCH 1078/1089] more --- python/sglang/srt/entrypoints/http_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 37e396d583e..64210561153 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -60,7 +60,7 @@ SetInternalStateReq, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - VertexGenerateReqInput, + VertexGenerateReqInput, EplbRebalanceReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer From 6b52b6d8200247fc43b009a91af3e49eb6788a7b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 17 Apr 2025 21:41:16 +0800 Subject: [PATCH 1079/1089] more --- python/sglang/srt/entrypoints/http_server.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 64210561153..288c3257136 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -313,7 +313,7 @@ async def flush_cache(): ret = await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache 
flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, ) @@ -371,8 +371,8 @@ async def dump_expert_distribution_record_async(): @app.post("/eplb_rebalance") -async def eplb_rebalance(obj: EplbRebalanceReqInput): - await _global_state.tokenizer_manager.eplb_rebalance(obj) +async def eplb_rebalance(obj: Optional[EplbRebalanceReqInput] = None): + await _global_state.tokenizer_manager.eplb_rebalance(obj or EplbRebalanceReqInput()) return ORJSONResponse({}, status_code=200) @@ -646,10 +646,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, From 518ca32e29e00aefd9b9b3e9f27fdfa135328e4f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 10:08:20 +0800 Subject: [PATCH 1080/1089] more --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b9c1939a473..ee33e3e7e12 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -341,7 +341,7 @@ def __post_init__(self): if self.enable_eplb: self.enable_expert_distribution_recorder = True if self.expert_location_updater_mode is None: - self.expert_location_updater_mode = "pin_memory" + self.expert_location_updater_mode = "pageable_memory" logger.info( f"EPLB is enabled. The enable_expert_distribution_recorder and expert_location_updater_mode are automatically set." 
) From 02c622f71c3490c802ec38bb5e99c2255e87ab80 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 10:09:06 +0800 Subject: [PATCH 1081/1089] fmt --- python/sglang/srt/entrypoints/engine.py | 7 +++++-- python/sglang/srt/entrypoints/http_server.py | 13 +++++++------ python/sglang/srt/managers/eplb_manager.py | 19 ++++++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 8 +++++--- .../srt/model_executor/memory_transfer.py | 12 ++++++------ 5 files changed, 37 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index ab091fc59cf..d4c60544297 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -51,6 +51,7 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, + EplbRebalanceReqInput, GenerateReqInput, GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -61,7 +62,7 @@ UpdateExpertLocationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, EplbRebalanceReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -368,7 +369,9 @@ def update_weights_from_tensor( def eplb_rebalance(self): loop = asyncio.get_event_loop() - return loop.run_until_complete(self.tokenizer_manager.eplb_rebalance(EplbRebalanceReqInput())) + return loop.run_until_complete( + self.tokenizer_manager.eplb_rebalance(EplbRebalanceReqInput()) + ) def eplb_save_expert_distribution(self): loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 288c3257136..cedb0effc41 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -48,6 +48,7 @@ CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, + EplbRebalanceReqInput, GenerateReqInput, GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -60,7 +61,7 @@ SetInternalStateReq, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, - VertexGenerateReqInput, EplbRebalanceReqInput, + VertexGenerateReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer @@ -313,7 +314,7 @@ async def flush_cache(): ret = await _global_state.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, ) @@ -646,10 +647,10 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ] break image_data = [ - instance.get("image_data") - for instance in vertex_req.instances - if instance.get("image_data") is not None - ] or None + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None req = GenerateReqInput( **inputs, image_data=image_data, diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 9d865585e61..6e9506c1c43 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import torch + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers import deepseek_eplb from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage @@ -11,7 +12,10 @@ ExpertLocationMetadata, ModelConfigForExpertLocation, ) -from sglang.srt.managers.io_struct import UpdateExpertLocationReqInput, EplbRebalanceReqInput +from sglang.srt.managers.io_struct import ( + EplbRebalanceReqInput, + UpdateExpertLocationReqInput, +) from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -26,7 +30,7 @@ def __init__(self, server_args: ServerArgs): self._server_args = server_args self._expert_distribution_storage = ExpertDistributionStorage( dir_data=Path(self._server_args.eplb_storage_dir) - / "expert_distribution_storage" + / "expert_distribution_storage" ) def bind(self, tokenizer_manager: "TokenizerManager"): @@ -46,7 +50,8 @@ async def handle_loop(self): async def rebalance(self, obj: EplbRebalanceReqInput): await self.save_expert_distribution() expert_location_metadata = self.compute_expert_location_metadata( - debug_use_random_stat=obj.debug_use_random_stat) + debug_use_random_stat=obj.debug_use_random_stat + ) await self._tokenizer_manager.update_expert_location( UpdateExpertLocationReqInput( expert_location_metadata=expert_location_metadata @@ -62,8 +67,12 @@ def compute_expert_location_metadata(self, debug_use_random_stat: bool = False): return ExpertLocationMetadata.init_trivial(self._server_args) if debug_use_random_stat: - logger.warning("EPLBManager.compute_expert_location_metadata use random stat for debugging.") + logger.warning( + "EPLBManager.compute_expert_location_metadata use random stat for debugging." 
+ ) original_logical_count = torch.tensor(snapshot["logical_count"]) - snapshot = {"logical_count": torch.randint_like(original_logical_count, high=100000)} + snapshot = { + "logical_count": torch.randint_like(original_logical_count, high=100000) + } return ExpertLocationMetadata.init_by_eplb(self._server_args, **snapshot) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ff0854ef21a..794bc90a21b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.conn import KVBootstrapServer @@ -65,6 +66,7 @@ CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, + EplbRebalanceReqInput, ExpertDistributionReq, ExpertDistributionReqOutput, FlushCacheReqInput, @@ -96,7 +98,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, EplbRebalanceReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.multimodal_processor import ( get_dummy_processor, @@ -1048,8 +1050,8 @@ def _handle_batch_output( elif isinstance(recv_obj, BatchTokenIDOut): if self.server_args.stream_output and state.obj.stream: output_token_ids = recv_obj.output_ids[i][ - state.last_output_offset: - ] + state.last_output_offset : + ] state.last_output_offset = len(recv_obj.output_ids[i]) else: output_token_ids = recv_obj.output_ids[i] diff --git a/python/sglang/srt/model_executor/memory_transfer.py b/python/sglang/srt/model_executor/memory_transfer.py index 3ef475ab865..ac50a646be2 100644 --- a/python/sglang/srt/model_executor/memory_transfer.py +++ b/python/sglang/srt/model_executor/memory_transfer.py @@ -25,9 +25,9 @@ def get_outputs(self) -> List[NamedTensors]: class CombinedManager(TensorOperationManagerBase): def __init__( - self, - manager_a: TensorOperationManagerBase, - manager_b: TensorOperationManagerBase, + self, + manager_a: TensorOperationManagerBase, + manager_b: TensorOperationManagerBase, ): # For simplicity, only support chaining 2 managers, but can be extended to N self._manager_a = manager_a @@ -120,8 +120,8 @@ def enqueue(self, named_tensors: NamedTensors): def get_outputs(self) -> List[NamedTensors]: outputs = [] while ( - len(self._inflight_tasks) > 0 - and self._inflight_tasks[0].finish_event.query() + len(self._inflight_tasks) > 0 + and self._inflight_tasks[0].finish_event.query() ): task = self._inflight_tasks.pop(0) outputs.append(self._handle_one_output(task)) @@ -137,7 +137,7 @@ def _auto_create_stream(self): @staticmethod def _tensor_to_cuda( - input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator" + input_tensor: torch.Tensor, allocator: "SimpleCachingAllocator" ): output_tensor = allocator.allocate(input_tensor.size(), input_tensor.dtype) output_tensor.copy_(input_tensor, non_blocking=True) From a880d021a15353d0af8c8d247c90f7cadf83d619 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 10:32:22 +0800 Subject: [PATCH 1082/1089] more --- python/sglang/srt/managers/expert_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 55a6c6069eb..1240e361c69 100644 --- 
a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -519,7 +519,7 @@ def append( single_pass_global_physical_count: torch.Tensor, ): # Can optimize if overhead here is large - self._buffer_global_physical_count += single_pass_global_physical_count + self._buffer_global_physical_count += single_pass_global_physical_count.cpu() def reset(self): self._buffer_global_physical_count[...] = 0 From 30728d206da10264cdbdb88785970d8a48fee3f5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 11:00:27 +0800 Subject: [PATCH 1083/1089] more --- python/sglang/srt/model_executor/model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 14d5bdccd28..36a44cdb1ea 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -517,6 +517,8 @@ def update_expert_location_start(self, recv_req: UpdateExpertLocationReqInput): self._expert_location_updater.start(recv_req) def event_loop_step(self) -> List[Any]: + if self._expert_location_updater is None: + return [] return self._expert_location_updater.event_loop_step() def update_weights_from_disk( From a7445c4e31fa28bf5d83055c3af10d3e20c50c0a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:02:01 +0800 Subject: [PATCH 1084/1089] fix merge --- python/sglang/srt/disaggregation/decode.py | 1 + python/sglang/srt/disaggregation/prefill.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index d106e42d45f..3283d7943cf 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -426,6 +426,7 @@ def event_loop_normal_disagg_decode(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() # polling and allocating kv cache self.process_decode_queue() batch = self.get_next_disagg_decode_batch_to_run() diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index d513b13dda1..a2eae9b9ae8 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -178,6 +178,7 @@ def event_loop_normal_disagg_prefill(self): while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) + self.model_runner_event_loop_step() self.waiting_queue.extend( self.disagg_prefill_pending_queue.pop_bootstrapped() ) From c61324e9475bfd3525a2b0c7f4c7ddde1aaa6194 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:02:21 +0800 Subject: [PATCH 1085/1089] fmt --- python/sglang/srt/models/deepseek_v2.py | 162 ++++++++++++------------ 1 file changed, 80 insertions(+), 82 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2dbf063db5c..169446ddcb4 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -20,8 +20,6 @@ import os import re from dataclasses import dataclass -from enum import Enum, auto -from dataclasses import dataclass from enum import Enum, IntEnum, auto from typing import Any, Dict, Iterable, Optional, Tuple @@ -56,8 +54,8 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE -from 
sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher +from sglang.srt.layers.moe.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -1450,85 +1448,85 @@ def post_load_weights(self, enable_mla_postprocess: bool = True): # Perform post-processing after loading weights if enable_mla_postprocess: - for layer_id in range(self.config.num_hidden_layers): - self_attn = self.model.layers[layer_id].self_attn - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T - else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - if w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - if hasattr(self.quant_config, "weight_block_size"): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - self_attn.w_scale = scale - - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 - ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # 
NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + if hasattr(self.quant_config, "weight_block_size"): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ From 5ce21cdf9217b4bd1151278dc902e2e35d6b26f6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:07:52 +0800 Subject: [PATCH 1086/1089] fix merge --- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 2b38a657225..95d38103d20 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs +from sglang.srt.utils import DeepEPMode, MaybeDisposibleTensor, is_hip, set_weight_attrs _is_hip = is_hip() From 5286fdd7364f3718b93b19a6823ca9531453d982 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:09:01 +0800 Subject: [PATCH 1087/1089] fix merge --- python/sglang/srt/layers/moe/ep_moe/layer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 95d38103d20..d67910b0a90 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -42,7 +42,14 @@ from sglang.srt.layers.quantization.fp8 import 
Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import DeepEPMode, MaybeDisposibleTensor, is_hip, set_weight_attrs +from sglang.srt.utils import ( + DeepEPMode, + DisposibleTensor, + MaybeDisposibleTensor, + TensorCreator, + is_hip, + set_weight_attrs, +) _is_hip = is_hip() From fb773702afdb57b2af0269872a65f346ae3c79aa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:11:49 +0800 Subject: [PATCH 1088/1089] fix merge --- python/sglang/srt/models/deepseek_v2.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 169446ddcb4..f45cdaa7a73 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -111,7 +111,6 @@ class AttnForwardMethod(IntEnum): - # Use multi-head attention MHA = auto() @@ -1044,19 +1043,6 @@ class _DecoderLayerInfo: ffn_input_mode: _FFNInputMode -class _DecoderLayerExecutionMode(Enum): - # The MLP sublayer requires 1/tp_size tokens as input - MLP_INPUT_ONE = auto() - # The MLP sublayer requires all tokens as input - MLP_INPUT_ALL = auto() - - -@dataclass -class _DecoderLayerInfo: - is_sparse: bool - execution_mode: _DecoderLayerExecutionMode - - class DeepseekV2DecoderLayer(nn.Module): def __init__( From f100686739e9971832fd79f63790ce2e6c24f9ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 18 Apr 2025 12:15:07 +0800 Subject: [PATCH 1089/1089] fix merge --- python/sglang/srt/layers/moe/topk.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index fbb66b71a38..305d4562330 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -236,8 +236,9 @@ def biased_grouped_topk( _is_cuda and n_share_experts_fusion == 0 and is_power_of_two(correction_bias.shape[0]) + # TODO fuse into the kernel + and expert_location_dispatch_info is None ): - assert expert_location_dispatch_info is None return moe_fused_gate( gating_output, correction_bias,
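
Patches 1069/1070 repair the caching allocator in memory_transfer.py: the old
code handed the bound method `input_tensor.size` (rather than calling it) to
`allocate`, and released buffers under the key `(tensor.size, tensor.dtype)`,
so pool lookups could never match. Below is a minimal sketch of the fixed
pattern; the `Sketch` suffix marks it as illustrative, the device is fixed to
CPU so it runs anywhere, and the real class's bookkeeping is omitted. The key
point is that both the allocate and release paths normalize the shape to a
plain tuple so their pool keys agree:

    from collections import defaultdict
    from typing import Dict, List, Tuple

    import torch


    class SimpleCachingAllocatorSketch:
        def __init__(self, device: str = "cpu"):
            self._device = device
            # Free buffers keyed by (shape-as-plain-tuple, dtype); patch 1070
            # applies the same tuple(...) normalization on release so that
            # allocate() can find tensors recycled by mark_all_unused().
            self._unused_pool: Dict[Tuple[tuple, torch.dtype], List[torch.Tensor]] = (
                defaultdict(list)
            )
            self._used_pool: List[torch.Tensor] = []

        def allocate(self, size, dtype) -> torch.Tensor:
            size = tuple(size)  # accepts torch.Size or any iterable of ints
            entry = self._unused_pool[(size, dtype)]
            if entry:
                output = entry.pop()
            else:
                output = torch.empty(*size, dtype=dtype, device=self._device)
            self._used_pool.append(output)
            return output

        def mark_all_unused(self) -> None:
            for tensor in self._used_pool:
                self._unused_pool[(tuple(tensor.size()), tensor.dtype)].append(tensor)
            self._used_pool.clear()


    # Callers pass tensor.size() (the patch-1069 fix), not the bound method:
    allocator = SimpleCachingAllocatorSketch()
    t = torch.ones(2, 3)
    buf = allocator.allocate(t.size(), t.dtype)
    allocator.mark_all_unused()
    assert allocator.allocate(t.size(), t.dtype) is buf  # recycled, not reallocated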
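
Patch 1071 changes update_expert_location to send the per-layer expert
location update in at most five evenly sized chunks instead of fixed chunks
of ten layers, so the number of update rounds no longer grows with model
depth. The schedule is easy to check in isolation (the layer counts below are
arbitrary examples, not tied to any particular model):

    import math


    def layer_id_lens(num_layers: int) -> list:
        # Same arithmetic as the patch: at most five chunks, with the last
        # chunk always ending exactly at num_layers.
        step = math.ceil(num_layers / 5)
        return list(range(step, num_layers, step)) + [num_layers]


    assert layer_id_lens(61) == [13, 26, 39, 52, 61]
    assert layer_id_lens(10) == [2, 4, 6, 8, 10]
    assert layer_id_lens(4) == [1, 2, 3, 4]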
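
Patches 1072 through 1079 thread an EplbRebalanceReqInput from the
/eplb_rebalance endpoint through the tokenizer manager into
EPLBManager.rebalance; with debug_use_random_stat=True the recorded
logical_count is replaced by torch.randint_like(..., high=100000), so
rebalancing can be exercised without first collecting real traffic
statistics. A sketch of triggering the debug path, assuming a server already
running on the default local port (the request body may also be omitted
entirely, since patch 1079 makes the request object optional):

    import requests

    # Hypothetical invocation; host/port depend on how the server was started.
    resp = requests.post(
        "http://localhost:30000/eplb_rebalance",
        json={"debug_use_random_stat": True},
    )
    assert resp.status_code == 200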
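
Patch 1082 appends .cpu() when accumulating per-pass expert counts, which
implies the accumulation buffer is host-resident while the per-pass counts
arrive on the GPU: an in-place += between tensors on different devices raises
a device-mismatch error. In miniature (the shape here is arbitrary):

    import torch

    buffer_global_physical_count = torch.zeros(8, dtype=torch.int64)  # host buffer
    single_pass = torch.ones(8, dtype=torch.int64)
    if torch.cuda.is_available():
        single_pass = single_pass.to("cuda")
    # Without the copy, `buffer_global_physical_count += single_pass` raises
    # when single_pass lives on CUDA; .cpu() is a no-op for a CPU tensor.
    buffer_global_physical_count += single_pass.cpu()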