34 changes: 21 additions & 13 deletions python/sglang/srt/managers/expert_location.py
@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
# (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]

# -------------------------------- properties ------------------------------------

@@ -70,11 +71,8 @@ def __post_init__(self):
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
num_layers_3, num_logical_experts_2 = (
self.logical_to_rank_dispatch_physical_map.shape
)
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_layers_0 == num_layers_1 == num_layers_2
assert num_logical_experts_0 == num_logical_experts_1
assert num_physical_experts_0 == num_physical_experts_1

# -------------------------------- construction ------------------------------------
@@ -117,6 +115,7 @@ def init_by_mapping(
)

return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map,
@@ -154,6 +153,7 @@ def init_by_eplb(
)

return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map.to(server_args.device),
logical_to_all_physical_map=logical_to_all_physical_map.to(
@@ -184,6 +184,7 @@ def _init_common(server_args: ServerArgs, model_config: ModelConfig):

@staticmethod
def _init_raw(
server_args: ServerArgs,
ep_size: int,
physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor,
@@ -204,12 +205,16 @@ def _init_raw(
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
# TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
# TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
)
if server_args.ep_dispatch_algorithm == "static"
else None
),
)

@@ -230,8 +235,11 @@ def update(
"logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map",
]:
src = getattr(other, field)
dst = getattr(self, field)
dst[...] = getattr(other, field)
assert (src is not None) == (dst is not None)
Review comment (Contributor, severity: medium):

The assertion (src is not None) == (dst is not None) ensures that self and other have the same optional status for the field being updated (e.g., logical_to_rank_dispatch_physical_map). This implies they were likely configured with the same ep_dispatch_algorithm.

Could you clarify if it's an expected invariant that the update method is only called between ExpertLocationMetadata instances that have a consistent configuration regarding the presence of this map? If there are scenarios where self and other might have different configurations (e.g., one static, one dynamic), this assertion would prevent the update. Understanding the intended state transitions here would be helpful.

if dst is not None:
dst[...] = src

# -------------------------------- usage ------------------------------------

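Note on the diff above: logical_to_rank_dispatch_physical_map is now Optional — it is computed only when server_args.ep_dispatch_algorithm == "static" and is None otherwise — and update() asserts that both instances agree on whether the map is present before copying in place. The following is a minimal, self-contained sketch of that pattern (the _MetadataSketch class is hypothetical, not SGLang code):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _MetadataSketch:  # hypothetical stand-in for ExpertLocationMetadata
    physical_to_logical_map: torch.Tensor
    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]  # None unless dispatch is "static"

    def update(self, other: "_MetadataSketch") -> None:
        for field in ("physical_to_logical_map", "logical_to_rank_dispatch_physical_map"):
            src, dst = getattr(other, field), getattr(self, field)
            # Both sides must have been built with the same dispatch configuration.
            assert (src is not None) == (dst is not None)
            if dst is not None:
                dst[...] = src  # copy in place so existing references stay valid


static_meta = _MetadataSketch(torch.zeros(2, 4), torch.zeros(2, 8, dtype=torch.int64))
static_meta.update(_MetadataSketch(torch.ones(2, 4), torch.ones(2, 8, dtype=torch.int64)))

dynamic_meta = _MetadataSketch(torch.zeros(2, 4), None)  # e.g. ep_dispatch_algorithm == "random"
dynamic_meta.update(_MetadataSketch(torch.ones(2, 4), None))  # fine: both sides are None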
13 changes: 9 additions & 4 deletions python/sglang/srt/managers/expert_location_dispatch.py
@@ -25,7 +25,7 @@
class ExpertLocationDispatchInfo:
ep_dispatch_algorithm: Literal["static", "random"]
# (num_logical_experts,)
partial_logical_to_rank_dispatch_physical_map: torch.Tensor
partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
# (num_logical_experts, X)
partial_logical_to_all_physical_map: torch.Tensor
# (num_logical_experts,)
@@ -42,9 +42,14 @@ def init_new(cls, layer_id: int):

return cls(
ep_dispatch_algorithm=ep_dispatch_algorithm,
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
layer_id, :
],
partial_logical_to_rank_dispatch_physical_map=(
expert_location_metadata.logical_to_rank_dispatch_physical_map[
layer_id, :
]
if expert_location_metadata.logical_to_rank_dispatch_physical_map
is not None
else None
),
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
layer_id, :
],
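For context on how these now-optional fields might be consumed downstream, here is a rough sketch (the _dispatch_sketch function is hypothetical, not the actual expert_location_dispatch.py logic): "static" indexes the precomputed per-rank map, while "random" samples among the valid physical replicas recorded in partial_logical_to_all_physical_map.

from typing import Literal, Optional

import torch


def _dispatch_sketch(
    topk_ids: torch.Tensor,  # logical expert ids selected by the router
    ep_dispatch_algorithm: Literal["static", "random"],
    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor],  # (num_logical_experts,)
    partial_logical_to_all_physical_map: torch.Tensor,  # (num_logical_experts, X), padded
    partial_logical_to_all_physical_map_num_valid: torch.Tensor,  # (num_logical_experts,)
) -> torch.Tensor:
    if ep_dispatch_algorithm == "static":
        # The per-rank map is present only in static mode; it is None otherwise, per the diff above.
        return partial_logical_to_rank_dispatch_physical_map[topk_ids]
    # "random": draw one of the num_valid physical replicas for each selected logical expert.
    num_valid = partial_logical_to_all_physical_map_num_valid[topk_ids]
    chosen = (torch.rand(num_valid.shape, device=num_valid.device) * num_valid).long()
    return partial_logical_to_all_physical_map[topk_ids, chosen]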