Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions python/sglang/srt/eplb/expert_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def _init_raw(
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
Expand Down Expand Up @@ -340,6 +341,7 @@ def _pad_nested_array(arr, pad_value):

# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor,
num_gpus: int,
num_physical_experts: int,
Expand All @@ -348,7 +350,9 @@ def compute_logical_to_rank_dispatch_physical_map(
):
r = random.Random(seed)

num_local_physical_experts = num_physical_experts // num_gpus
num_local_gpu_physical_experts = num_physical_experts // num_gpus
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype

Expand All @@ -372,13 +376,28 @@ def compute_logical_to_rank_dispatch_physical_map(
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert(
physical_expert_id, num_local_physical_experts
physical_expert_id, num_local_gpu_physical_experts
)
== gpu_id
]
if len(same_gpu_physical_expert_ids) > 0:
# 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]

else:
# 2. Otherwise, prefer same-node experts
node_id = gpu_id // num_gpus_per_node
same_node_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_node_id_of_physical_expert(
physical_expert_id, num_local_node_physical_experts
)
== node_id
]
if len(same_node_physical_expert_ids) > 0:
output_partial[gpu_id] = same_node_physical_expert_ids[0]

# 3. Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
Expand All @@ -404,9 +423,15 @@ def _logical_to_all_physical_raw(


def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_physical_experts: int
physical_expert_id: int, num_local_gpu_physical_experts: int
) -> int:
return physical_expert_id // num_local_gpu_physical_experts


def _compute_node_id_of_physical_expert(
physical_expert_id: int, num_local_host_physical_experts: int
) -> int:
return physical_expert_id // num_local_physical_experts
return physical_expert_id // num_local_host_physical_experts


def _fair_choices(arr: List, k: int, r: random.Random) -> List:
Expand Down
Loading