From 686b024dbd027cab103d2535bf57c5064443188c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 5 Mar 2026 23:08:15 +0000 Subject: [PATCH 1/9] sync eplb works, async hangs Signed-off-by: Sage Moore --- vllm/distributed/eplb/async_worker.py | 16 +-- vllm/distributed/eplb/eplb_state.py | 146 ++++++++++++++--------- vllm/distributed/eplb/policy/abstract.py | 6 +- vllm/distributed/eplb/policy/default.py | 15 +-- 4 files changed, 93 insertions(+), 90 deletions(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 5dd862f36bc2..eaf01ea15068 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -73,11 +73,7 @@ def run_rebalance_experts( # Move the global expert load window to CPU for computation. global_expert_load_window = eplb_stats.global_expert_load_window.cpu() # Compute new expert mappings for the model - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = eplb_state.policy.rebalance_experts( + new_physical_to_logical_map = eplb_state.policy.rebalance_experts( global_expert_load_window, eplb_stats.num_replicas, eplb_stats.num_groups, @@ -89,16 +85,6 @@ def run_rebalance_experts( model_state.new_physical_to_logical_map = new_physical_to_logical_map - max_slots = model_state.logical_to_physical_map.shape[-1] - padded_logical = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])), - value=-1, - ).to(model_state.logical_to_physical_map.device) - new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device) - model_state.new_logical_to_physical_map = padded_logical - model_state.new_logical_replica_count = new_replica - async def transfer_run_periodically( state: "EplbState", diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index b417c2b3256a..b95c087c32d3 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -235,16 +235,6 @@ class EplbModelState: intermediate variable between `move_to_buffer` and `move_to_workspace`. the size is same as physical_to_logical_map """ - new_logical_to_physical_map: torch.Tensor | None = None - """ - intermediate variable between `move_to_buffer` and `move_to_workspace`. - the size is same as logical_to_physical_map - """ - new_logical_replica_count: torch.Tensor | None = None - """ - intermediate variable between `move_to_buffer` and `move_to_workspace`. - the size is same as logical_replica_count - """ class EplbState: @@ -508,8 +498,6 @@ def add_model( ), cuda_device_index=self.cuda_device_index, new_physical_to_logical_map=None, - new_logical_to_physical_map=None, - new_logical_replica_count=None, ) self.model_states[model_config.compute_hash()] = model_state self.num_valid_physical_experts = model.num_physical_experts @@ -738,17 +726,20 @@ def rearrange( ): if not self.is_async or is_profile: # Get new expert mappings for the model - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = self.policy.rebalance_experts( - global_expert_load_window, + new_physical_to_logical_map = self.policy.rebalance_experts( + global_expert_load_window.cpu(), num_replicas, num_groups, num_nodes, num_gpus, - eplb_model_state.physical_to_logical_map, + eplb_model_state.physical_to_logical_map.cpu(), + ) + + num_logical_experts = global_expert_load_window.shape[-1] + (new_logical_to_physical_map, new_logical_replica_count) = ( + compute_logical_maps( + new_physical_to_logical_map, num_logical_experts + ) ) # Update expert weights @@ -847,11 +838,7 @@ def start_async_loop( def _update_layer_mapping_from_new( self, model_state: EplbModelState, layer: int ) -> None: - if ( - model_state.new_physical_to_logical_map is None - or model_state.new_logical_to_physical_map is None - or model_state.new_logical_replica_count is None - ): + if model_state.new_physical_to_logical_map is None: return target_device = model_state.physical_to_logical_map.device @@ -865,19 +852,23 @@ def _update_layer_mapping_from_new( new_physical[layer].to(target_device, non_blocking=True) ) + num_logical_experts = model_state.logical_to_physical_map.shape[1] + new_logical, new_replica_count = compute_logical_maps( + new_physical[layer], num_logical_experts + ) + logical_device = model_state.logical_to_physical_map.device - new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device) max_slots = model_state.logical_to_physical_map.shape[-1] slot_delta = max_slots - new_logical.shape[-1] if slot_delta > 0: new_logical = torch.nn.functional.pad( new_logical, (0, slot_delta), value=-1 ) - model_state.logical_to_physical_map[layer].copy_(new_logical) + model_state.logical_to_physical_map[layer].copy_(new_logical.to(logical_device)) replica_device = model_state.logical_replica_count.device model_state.logical_replica_count[layer].copy_( - model_state.new_logical_replica_count[layer].to(replica_device) + new_replica_count.to(replica_device) ) def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool: @@ -989,12 +980,7 @@ def move_to_workspace( def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None: assert model_state.new_physical_to_logical_map is not None - assert model_state.new_logical_to_physical_map is not None - assert model_state.new_logical_replica_count is not None - model_state.new_physical_to_logical_map = None - model_state.new_logical_to_physical_map = None - model_state.new_logical_replica_count = None def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: """ @@ -1052,37 +1038,14 @@ def from_mapping( model_config=model_config, ) eplb_state.num_valid_physical_experts = num_valid_physical_experts - num_moe_layers = expanded_physical_to_logical.shape[0] - num_physical_experts = expanded_physical_to_logical.shape[1] eplb_model_state = eplb_state.model_states[model_config.compute_hash()] eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) - logical_to_physical_map = torch.full( - ( - num_moe_layers, - model.num_logical_experts, - eplb_model_state.logical_to_physical_map.shape[2], - ), - -1, - dtype=torch.int64, - ) - logical_replica_count = torch.zeros( - (num_moe_layers, model.num_logical_experts), - dtype=torch.int64, + (logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps( + expanded_physical_to_logical.cpu(), model.num_logical_experts ) - expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy() - for layer_idx in range(num_moe_layers): - for phys_idx in range(num_physical_experts): - logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx] - if logical_idx >= 0: - replica_idx = logical_replica_count[layer_idx, logical_idx] - logical_to_physical_map[layer_idx, logical_idx, replica_idx] = ( - phys_idx - ) - logical_replica_count[layer_idx, logical_idx] += 1 - - logical_to_physical_map = logical_to_physical_map.to(device) - logical_replica_count = logical_replica_count.to(device) + logical_to_physical_map = logical_to_physical_map_cpu.to(device) + logical_replica_count = logical_replica_count_cpu.to(device) eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) eplb_model_state.logical_replica_count.copy_(logical_replica_count) return eplb_state @@ -1132,3 +1095,68 @@ def _node_count_with_rank_mapping( node_assignment[other_rank] = next_node_id return next_node_id + + +def compute_logical_maps( + physical_to_logical_map: torch.Tensor, + num_logical_experts: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Derive logical_to_physical_map and logical_replica_count from + physical_to_logical_map. + + Args: + physical_to_logical_map: [num_layers, num_physical_experts], logical + expert index for each physical expert slot + num_logical_experts: total number of logical experts + + Returns: + logical_to_physical_map: [num_layers, num_logical_experts, max_replicas], + physical slots per logical expert; -1 where unused + logical_replica_count: [num_layers, num_logical_experts], number of + physical replicas per logical expert + """ + device = physical_to_logical_map.device + assert physical_to_logical_map.device.type == "cpu" + + dtype = physical_to_logical_map.dtype + + # If computing maps for a single layer, unsqueeze a single element layer dimension + per_layer = physical_to_logical_map.dim() == 1 + physical_to_logical_map_view = physical_to_logical_map + if per_layer: + physical_to_logical_map_view = physical_to_logical_map.unsqueeze(0) + assert len(physical_to_logical_map_view.shape) == 2 + num_layers, num_physical = physical_to_logical_map_view.shape + + logical_replica_count = torch.zeros( + num_layers, + num_logical_experts, + dtype=dtype, + device=device, + ) + logical_replica_count.scatter_add_( + 1, physical_to_logical_map_view, torch.ones_like(physical_to_logical_map_view) + ) + + max_replicas = int(logical_replica_count.max().item()) + logical_to_physical_map_out = torch.full( + (num_layers, num_logical_experts, max_replicas), + -1, + dtype=dtype, + device=device, + ) + + running_count = torch.zeros_like(logical_replica_count) + layer_indices = torch.arange(num_layers, device=device) + for phys_idx in range(num_physical): + expert_ids = physical_to_logical_map_view[:, phys_idx] # [num_layers] + replica_idx = running_count[layer_indices, expert_ids] # [num_layers] + logical_to_physical_map_out[layer_indices, expert_ids, replica_idx] = phys_idx + running_count[layer_indices, expert_ids] += 1 + + # If computing maps for a single layer, squeeze out the extra layer dimension + # before returning + if per_layer: + return logical_to_physical_map_out.squeeze(0), logical_replica_count.squeeze(0) + return logical_to_physical_map_out, logical_replica_count diff --git a/vllm/distributed/eplb/policy/abstract.py b/vllm/distributed/eplb/policy/abstract.py index f4435f11bd57..d056468b97b2 100644 --- a/vllm/distributed/eplb/policy/abstract.py +++ b/vllm/distributed/eplb/policy/abstract.py @@ -17,7 +17,7 @@ def rebalance_experts( num_nodes: int, num_ranks: int, old_global_expert_indices: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ Entry point for expert-parallelism load balancer. @@ -35,9 +35,5 @@ def rebalance_experts( Returns: physical_to_logical_map: [layers, num_replicas], the expert index of each replica - logical_to_physical_map: [layers, num_logical_experts, X], - the replica indices for each expert - expert_count: [layers, num_logical_experts], number of - physical replicas for each logical expert """ raise NotImplementedError diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index b9cfcae01410..8929a6465c58 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -302,7 +302,7 @@ def rebalance_experts( num_nodes: int, num_ranks: int, old_global_expert_indices: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ Entry point for expert-parallelism load balancer. @@ -321,12 +321,7 @@ def rebalance_experts( Returns: phy2log: [layers, num_replicas], the expert index of each replica - log2phy: [layers, num_logical_experts, X], - the replica indices for each expert - logcnt: [layers, num_logical_experts], number of - physical replicas for each logical expert """ - device = weight.device num_layers, num_logical_experts = weight.shape weight_np = weight.float().cpu().numpy() old_phy2log_np = ( @@ -355,7 +350,7 @@ def rebalance_experts( # Only apply when the number of GPUs and slots per GPU remain unchanged. # Helps to avoid unnecessary weight copying when experts move # within the same GPU. - if old_global_expert_indices is not None: + if old_phy2log_np is not None: phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots( phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np ) @@ -370,7 +365,5 @@ def rebalance_experts( ) log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices - phy2log = torch.from_numpy(phy2log_np).to(device) - log2phy = torch.from_numpy(log2phy_np).to(device) - logcnt = torch.from_numpy(logcnt_np).to(device) - return phy2log, log2phy, logcnt + phy2log = torch.from_numpy(phy2log_np) + return phy2log From 7e7db8f53504b1d549bb6014fcaf1901601e198a Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 5 Mar 2026 23:22:37 +0000 Subject: [PATCH 2/9] more cleanup Signed-off-by: Sage Moore --- vllm/distributed/eplb/policy/default.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 8929a6465c58..546499b5a20a 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -322,7 +322,6 @@ def rebalance_experts( phy2log: [layers, num_replicas], the expert index of each replica """ - num_layers, num_logical_experts = weight.shape weight_np = weight.float().cpu().numpy() old_phy2log_np = ( old_global_expert_indices.cpu().numpy() @@ -332,17 +331,13 @@ def rebalance_experts( if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log_np, phy_replicas_idx_np, logcnt_np = ( - cls.rebalance_experts_hierarchical( - weight_np, num_replicas, num_groups, num_nodes, num_ranks - ) + phy2log_np, phy_replicas_idx_np, _ = cls.rebalance_experts_hierarchical( + weight_np, num_replicas, num_groups, num_nodes, num_ranks ) else: # use global load-balance policy - phy2log_np, phy_replicas_idx_np, logcnt_np = ( - cls.rebalance_experts_hierarchical( - weight_np, num_replicas, 1, 1, num_ranks - ) + phy2log_np, phy_replicas_idx_np, _ = cls.rebalance_experts_hierarchical( + weight_np, num_replicas, 1, 1, num_ranks ) # Optional postprocessing to preserve slots for experts moving @@ -354,16 +349,6 @@ def rebalance_experts( phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots( phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np ) - num_redundant_experts = num_replicas - num_logical_experts - maxlogcnt = num_redundant_experts + 1 - log2phy_np = np.full( - (num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64 - ) - layer_indices = np.arange(num_layers)[:, None] - replica_indices = np.tile( - np.arange(num_replicas, dtype=np.int64), (num_layers, 1) - ) - log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices phy2log = torch.from_numpy(phy2log_np) return phy2log From d6f54450f261c3d05315341d3cd3741085f62954 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Mar 2026 15:01:48 +0000 Subject: [PATCH 3/9] padding fix Signed-off-by: Sage Moore --- vllm/distributed/eplb/eplb_state.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index b95c087c32d3..95b1005b4544 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -957,7 +957,7 @@ def move_to_workspace( transferred_layer, ) if model_state.layer_to_transfer >= model_state.model.num_moe_layers: - self.post_eplb(model_state, is_profile) + self.post_eplb(model_state) model_state.rebalanced = False model_state.layer_to_transfer = 0 model_state.pending_global_ready_check = False @@ -978,7 +978,7 @@ def move_to_workspace( str(e), ) - def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None: + def post_eplb(self, model_state: EplbModelState) -> None: assert model_state.new_physical_to_logical_map is not None model_state.new_physical_to_logical_map = None @@ -1044,10 +1044,22 @@ def from_mapping( (logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps( expanded_physical_to_logical.cpu(), model.num_logical_experts ) - logical_to_physical_map = logical_to_physical_map_cpu.to(device) + + max_num_replicas = eplb_model_state.logical_to_physical_map.shape[-1] + num_replicas = logical_to_physical_map_cpu.shape[-1] + logical_to_physical_map = torch.nn.functional.pad( + logical_to_physical_map_cpu, + ( + 0, + max_num_replicas - num_replicas, + ), + value=-1, + ).to(device) logical_replica_count = logical_replica_count_cpu.to(device) + eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) eplb_model_state.logical_replica_count.copy_(logical_replica_count) + return eplb_state From 765b836b204f4a004b5b01f583de1b36c55f9b2b Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Mar 2026 15:14:17 +0000 Subject: [PATCH 4/9] rebalance cleanup Signed-off-by: Sage Moore --- vllm/distributed/eplb/policy/default.py | 37 ++++++------------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 546499b5a20a..cb0941d459db 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -111,7 +111,7 @@ def rebalance_experts_hierarchical( num_groups: int, num_nodes: int, num_gpus: int, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> np.ndarray: """ Parameters: weight: [num_moe_layers, num_logical_experts] @@ -124,10 +124,6 @@ def rebalance_experts_hierarchical( Returns: phy2log: [layers, num_replicas], the expert index of each replica - pphy_replicas_idx: [layers, num_logical_experts, X], - the replica indices for each expert - logcnt: [layers, num_logical_experts], number of - physical replicas for each logical expert """ num_layers, num_logical_experts = weight.shape assert num_logical_experts % num_groups == 0 @@ -193,22 +189,15 @@ def inverse(perm: np.ndarray) -> np.ndarray: ).reshape(num_layers, -1) # Map node-local logical indices back to global logical expert ids. pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1) - # Reorder replica ranks to the post-packing physical ordering. - pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape( - num_layers, -1 - ) - # Convert replica counts back to the original logical ordering. - logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1) - return pphy2log, pphy_replicas_idx, logcnt + return pphy2log @classmethod def preserve_intragpu_slots( cls, phy2log: np.ndarray, - phy_replicas_idx: np.ndarray, num_ranks: int, old_phy2log: np.ndarray, - ) -> tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray: """ Reorder the new mapping per GPU so that experts that remain on the same GPU keep their previous slot positions when possible. Incoming experts to that GPU @@ -218,14 +207,13 @@ def preserve_intragpu_slots( """ num_phy_experts = phy2log.shape[1] if num_ranks <= 0 or num_phy_experts % num_ranks != 0: - return phy2log, phy_replicas_idx + return phy2log # Move to CPU and convert to NumPy for processing slots_per_gpu = num_phy_experts // num_ranks num_layers = phy2log.shape[0] post_phy2log = phy2log.copy() - post_phy_replicas_idx = phy_replicas_idx.copy() for gpu_idx in range(num_ranks): start = gpu_idx * slots_per_gpu @@ -233,7 +221,6 @@ def preserve_intragpu_slots( # Experts across all layers for this GPU old_local = old_phy2log[:, start:end] # [layers, slots] new_local = phy2log[:, start:end] # [layers, slots] - new_ridx = phy_replicas_idx[:, start:end] # [layers, slots] used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) @@ -253,9 +240,6 @@ def preserve_intragpu_slots( post_phy2log[layer_indices, start + slot_idx] = new_local[ layer_indices, matched_new_positions ] - post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[ - layer_indices, matched_new_positions - ] used_new_indices[layer_indices, matched_new_positions] = True preserved_positions[layer_indices, slot_idx] = True @@ -287,11 +271,8 @@ def preserve_intragpu_slots( post_phy2log[layer_idx, start + dst_pos] = new_local[ layer_idx, src_pos ] - post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[ - layer_idx, src_pos - ] - return post_phy2log, post_phy_replicas_idx + return post_phy2log @classmethod def rebalance_experts( @@ -331,12 +312,12 @@ def rebalance_experts( if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log_np, phy_replicas_idx_np, _ = cls.rebalance_experts_hierarchical( + phy2log_np = cls.rebalance_experts_hierarchical( weight_np, num_replicas, num_groups, num_nodes, num_ranks ) else: # use global load-balance policy - phy2log_np, phy_replicas_idx_np, _ = cls.rebalance_experts_hierarchical( + phy2log_np = cls.rebalance_experts_hierarchical( weight_np, num_replicas, 1, 1, num_ranks ) @@ -346,8 +327,8 @@ def rebalance_experts( # Helps to avoid unnecessary weight copying when experts move # within the same GPU. if old_phy2log_np is not None: - phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots( - phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np + phy2log_np = cls.preserve_intragpu_slots( + phy2log_np, num_ranks, old_phy2log_np ) phy2log = torch.from_numpy(phy2log_np) From 42f4b4b068c876883bb8bef49eb2a05b768d75ab Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Mar 2026 15:18:40 +0000 Subject: [PATCH 5/9] more rebalance cleanup Signed-off-by: Sage Moore --- vllm/distributed/eplb/policy/default.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index cb0941d459db..317097c023c5 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -75,7 +75,7 @@ def balanced_packing( @classmethod def replicate_experts( cls, weight: np.ndarray, num_phy: int - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. @@ -86,22 +86,19 @@ def replicate_experts( Returns: phy2log: [X, num_phy], logical expert id of each physical expert - replica_idx: [X, num_phy], the index of the replica for each logical expert logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape num_redundant = num_phy - num_log assert num_redundant >= 0 phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1)) - replica_idx = np.zeros((n, num_phy), dtype=np.int64) logcnt = np.ones((n, num_log), dtype=np.int64) arangen = np.arange(n, dtype=np.int64) for i in range(num_log, num_phy): redundant_indices = np.argmax(weight / logcnt, axis=-1) phy2log[:, i] = redundant_indices - replica_idx[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 - return phy2log, replica_idx, logcnt + return phy2log, logcnt @classmethod def rebalance_experts_hierarchical( @@ -163,7 +160,7 @@ def inverse(perm: np.ndarray) -> np.ndarray: tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape( -1, num_logical_experts // num_nodes ) - phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( + phy2mlog, mlogcnt = cls.replicate_experts( tokens_per_mlog, num_physical_experts // num_nodes ) From 83e721931be5f7fe819eee6bae7c4cb9a40930ac Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 6 Mar 2026 17:44:41 +0000 Subject: [PATCH 6/9] fix test_eplb_algo unit test Signed-off-by: Sage Moore --- tests/distributed/test_eplb_algo.py | 45 +++++++++++++++++------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index 6fe44fc21801..c482e50f1a4d 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm.distributed.eplb.eplb_state import compute_logical_maps from vllm.distributed.eplb.policy.default import DefaultEplbPolicy @@ -24,9 +25,10 @@ def test_basic_rebalance(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify output shapes assert phy2log.shape == ( @@ -78,9 +80,10 @@ def test_single_gpu_case(): num_nodes = 1 num_gpus = 1 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify shapes assert phy2log.shape == (1, 4) @@ -100,9 +103,10 @@ def test_equal_weights(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify shapes assert phy2log.shape == (1, 8) @@ -123,9 +127,10 @@ def test_extreme_weight_imbalance(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify shapes assert phy2log.shape == (1, 12) @@ -151,9 +156,10 @@ def test_multiple_layers(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify shapes assert phy2log.shape == (3, 8) @@ -176,7 +182,8 @@ def test_parameter_validation(): # Test non-divisible case - this should handle normally without throwing # errors because the function will fall back to global load balancing # strategy - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4) + phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) assert phy2log.shape == (1, 8) assert logcnt.shape == (1, 4) @@ -198,9 +205,10 @@ def test_small_scale_hierarchical(): num_nodes = 2 # 2 nodes num_gpus = 4 # 4 GPUs - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Verify basic constraints assert phy2log.shape == (1, 12) @@ -225,9 +233,10 @@ def test_global_load_balance_fallback(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Should work normally, just using global load balancing strategy assert phy2log.shape == (1, 8) @@ -247,9 +256,10 @@ def test_device_compatibility(device): num_nodes = 1 num_gpus = 2 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) + _, logcnt = compute_logical_maps(phy2log, weight.shape[-1]) # Function will convert to CPU internally, but should handle different # device inputs normally @@ -264,9 +274,8 @@ def test_additional_cases(): weight1 = torch.tensor( [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] ) - phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts( - weight1, 24, 8, 4, 8 - ) + phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8) + _, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1]) assert phy2log1.shape == (1, 24) assert logcnt1.shape == (1, 16) @@ -279,9 +288,8 @@ def test_additional_cases(): [12, 25, 50, 100, 150, 200], # Increasing weights ] ) - phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts( - weight2, 10, 3, 1, 2 - ) + phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2) + _, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1]) assert phy2log2.shape == (2, 10) assert logcnt2.shape == (2, 6) @@ -305,7 +313,7 @@ def test_additional_cases(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( + phy2log = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) print(phy2log) @@ -434,9 +442,10 @@ def test_preserve_intragpu_slots( """Experts that stay on a GPU keep their old slots; incoming not lost.""" phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log) - post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots( - new_phy2log, phy_replicas_idx, num_ranks, old_phy2log + post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots( + new_phy2log, num_ranks, old_phy2log ) + post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log) # Shapes preserved assert post_phy2log.shape == new_phy2log.shape From faba44f67ea17719a4746c241b5cea1467e6322f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 9 Mar 2026 15:53:03 +0000 Subject: [PATCH 7/9] added support for -1 logical expert ids Signed-off-by: Sage Moore --- tests/distributed/test_eplb_algo.py | 48 +++++++++++++++++++++++++++++ vllm/distributed/eplb/eplb_state.py | 23 +++++++++++--- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index c482e50f1a4d..d36a4c5bb51b 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -300,6 +300,54 @@ def test_additional_cases(): assert logcnt2[layer, max_weight_idx] >= 2 +def test_compute_logical_maps_with_negative_indices(): + """ + Test that compute_logical_maps correctly handles physical slots containing + -1 (unused slots). Without the >= 0 guard, -1 would be treated as a valid + index via Python's negative indexing and corrupt the last expert's counts. + """ + # 2 layers, 6 physical slots, 4 logical experts. + # Slots 2 and 5 are unused (-1). + phy2log = torch.tensor( + [ + [0, 1, -1, 2, 3, -1], + [3, -1, 2, 1, 0, -1], + ] + ) + num_logical_experts = 4 + + log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts) + + # Shapes + assert logcnt.shape == (2, 4) + assert log2phy.shape[0] == 2 + assert log2phy.shape[1] == 4 + + # Each logical expert appears exactly once per layer + expected_logcnt = torch.ones(2, 4, dtype=phy2log.dtype) + assert torch.all(logcnt == expected_logcnt), ( + f"Expected all replica counts == 1, got {logcnt}" + ) + + # -1 slots must not inflate any expert's count + assert torch.all(logcnt >= 0), "No expert should have a negative count" + assert torch.all(logcnt <= 1), ( + "No expert should have more than 1 replica (no duplicates in input)" + ) + + # Unused slots (-1) should not appear in log2phy + assert torch.all(log2phy >= 0), ( + "log2phy should only contain valid physical indices, not -1 sentinel" + ) + + # Verify the actual physical slot assignments are correct (layer 0) + # Expert 0 -> slot 0, Expert 1 -> slot 1, Expert 2 -> slot 3, Expert 3 -> slot 4 + assert log2phy[0, 0, 0] == 0 + assert log2phy[0, 1, 0] == 1 + assert log2phy[0, 2, 0] == 3 + assert log2phy[0, 3, 0] == 4 + + if __name__ == "__main__": weight = torch.tensor( [ diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 95b1005b4544..d01ea69d82c5 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -1141,6 +1141,7 @@ def compute_logical_maps( assert len(physical_to_logical_map_view.shape) == 2 num_layers, num_physical = physical_to_logical_map_view.shape + valid_mask = physical_to_logical_map_view >= 0 logical_replica_count = torch.zeros( num_layers, num_logical_experts, @@ -1148,7 +1149,9 @@ def compute_logical_maps( device=device, ) logical_replica_count.scatter_add_( - 1, physical_to_logical_map_view, torch.ones_like(physical_to_logical_map_view) + 1, + physical_to_logical_map_view.clamp(min=0), + valid_mask.to(dtype), ) max_replicas = int(logical_replica_count.max().item()) @@ -1162,10 +1165,20 @@ def compute_logical_maps( running_count = torch.zeros_like(logical_replica_count) layer_indices = torch.arange(num_layers, device=device) for phys_idx in range(num_physical): - expert_ids = physical_to_logical_map_view[:, phys_idx] # [num_layers] - replica_idx = running_count[layer_indices, expert_ids] # [num_layers] - logical_to_physical_map_out[layer_indices, expert_ids, replica_idx] = phys_idx - running_count[layer_indices, expert_ids] += 1 + # Logical expert at physical slot phys_idx for each layer + logical_expert_ids = physical_to_logical_map_view[:, phys_idx] # [num_layers] + + # Only consider "valid" experts. I.E not -1 + valid_expert_mask = logical_expert_ids >= 0 + if not valid_expert_mask.any(): + continue + valid_layers = layer_indices[valid_expert_mask] + valid_experts = logical_expert_ids[valid_expert_mask] + + # Use the current running count as the replica index, then increment it. + replica_idx = running_count[valid_layers, valid_experts] + logical_to_physical_map_out[valid_layers, valid_experts, replica_idx] = phys_idx + running_count[valid_layers, valid_experts] += 1 # If computing maps for a single layer, squeeze out the extra layer dimension # before returning From 933e70cc5c86275034c1cff09d9ae62bb7476609 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 9 Mar 2026 16:04:06 +0000 Subject: [PATCH 8/9] comments Signed-off-by: Sage Moore --- vllm/distributed/eplb/eplb_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index d01ea69d82c5..22da146a3e03 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -1168,7 +1168,8 @@ def compute_logical_maps( # Logical expert at physical slot phys_idx for each layer logical_expert_ids = physical_to_logical_map_view[:, phys_idx] # [num_layers] - # Only consider "valid" experts. I.E not -1 + # Scale up will set the logical expert ids to -1 for all new physical experts. + # Only consider "valid" experts when setting up the logical_to_physical map. valid_expert_mask = logical_expert_ids >= 0 if not valid_expert_mask.any(): continue From a0534f33638b4076c303ce7a29697159cf8fb8c1 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 11 Mar 2026 15:39:05 +0000 Subject: [PATCH 9/9] test cleanup Signed-off-by: Sage Moore --- tests/distributed/test_eplb_algo.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index d36a4c5bb51b..721132d15b1d 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -303,8 +303,7 @@ def test_additional_cases(): def test_compute_logical_maps_with_negative_indices(): """ Test that compute_logical_maps correctly handles physical slots containing - -1 (unused slots). Without the >= 0 guard, -1 would be treated as a valid - index via Python's negative indexing and corrupt the last expert's counts. + -1 (unused slots). """ # 2 layers, 6 physical slots, 4 logical experts. # Slots 2 and 5 are unused (-1). @@ -314,34 +313,23 @@ def test_compute_logical_maps_with_negative_indices(): [3, -1, 2, 1, 0, -1], ] ) + num_layers = 2 num_logical_experts = 4 log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts) - # Shapes - assert logcnt.shape == (2, 4) - assert log2phy.shape[0] == 2 - assert log2phy.shape[1] == 4 + assert logcnt.shape == (num_layers, num_logical_experts) + assert log2phy.shape == (num_layers, num_logical_experts, 1) - # Each logical expert appears exactly once per layer - expected_logcnt = torch.ones(2, 4, dtype=phy2log.dtype) + expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype) assert torch.all(logcnt == expected_logcnt), ( - f"Expected all replica counts == 1, got {logcnt}" - ) - - # -1 slots must not inflate any expert's count - assert torch.all(logcnt >= 0), "No expert should have a negative count" - assert torch.all(logcnt <= 1), ( - "No expert should have more than 1 replica (no duplicates in input)" + f"Expected that all replica counts == 1, got {logcnt}" ) - # Unused slots (-1) should not appear in log2phy assert torch.all(log2phy >= 0), ( - "log2phy should only contain valid physical indices, not -1 sentinel" + "log2phy should only contain valid physical indices, not -1" ) - # Verify the actual physical slot assignments are correct (layer 0) - # Expert 0 -> slot 0, Expert 1 -> slot 1, Expert 2 -> slot 3, Expert 3 -> slot 4 assert log2phy[0, 0, 0] == 0 assert log2phy[0, 1, 0] == 1 assert log2phy[0, 2, 0] == 3