Merged
77 changes: 32 additions & 45 deletions python/sglang/srt/models/llama4.py
@@ -27,9 +27,8 @@
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
    dp_gather_partial,
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
@@ -367,7 +366,10 @@ def __init__(
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )
        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
        self.config = config
        is_moe_layer = self._is_moe_layer(layer_id)
        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)

        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
@@ -387,64 +389,49 @@
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=is_moe_layer,
            is_previous_layer_sparse=is_previous_moe_layer,
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def _is_moe_layer(self, layer_id: int) -> bool:
        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
Comment on lines +405 to +406
Contributor

medium

The current implementation of _is_moe_layer can produce an unexpected True for layer_id = -1, because (-1 + 1) % N evaluates to 0. This occurs when checking is_previous_moe_layer for the first decoder layer (layer_id = 0).

While the current logic in LayerScatterModes seems to prevent this from causing an issue, it makes the code less robust and potentially confusing for future modifications. The layer before the first decoder layer (i.e., the embedding layer) should be considered dense, not sparse.

To make this behavior explicit and prevent future bugs, I suggest adding a check for negative layer_id.

    def _is_moe_layer(self, layer_id: int) -> bool:
        if layer_id < 0:
            return False
        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
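
For illustration, here is a tiny standalone check of the edge case described above. The step value of 2 is assumed purely for the example (the real value comes from config.interleave_moe_layer_step), and is_moe_layer here is a free function rather than the model's method:

    step = 2  # assumed interleave_moe_layer_step, for illustration only

    # Without a guard, the "layer before layer 0" looks like an MoE layer:
    assert (-1 + 1) % step == 0  # True -> embedding side would be treated as sparse

    # With the suggested guard, negative layer ids are always treated as dense:
    def is_moe_layer(layer_id: int) -> bool:
        if layer_id < 0:
            return False
        return (layer_id + 1) % step == 0

    assert is_moe_layer(-1) is False  # embedding side is dense
    assert is_moe_layer(0) is False   # layer 0 is dense when step == 2
    assert is_moe_layer(1) is True    # layer 1 is MoE when step == 2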


    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            # Self Attention
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

            # Gather
            if get_tensor_model_parallel_world_size() > 1:
                # all gather and all reduce
                if self.local_dp_size != 1:
                    if self.attn_tp_rank == 0:
                        hidden_states += residual
                    hidden_states, local_hidden_states = (
                        forward_batch.gathered_buffer,
                        hidden_states,
                    )
                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                    dp_scatter(residual, hidden_states, forward_batch)
                    hidden_states = self.post_attention_layernorm(hidden_states)
                else:
                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        # Fully Connected
        hidden_states = self.feed_forward(hidden_states, forward_batch)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.local_dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual
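
For readability, here is how the new forward body reads once the deleted lines above are dropped. This is reconstructed from the added and unchanged lines in this diff, not copied from the merged file:

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # prepare_attn replaces the manual input_layernorm handling removed above
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        # prepare_mlp replaces the manual gather/all-reduce + post_attention_layernorm block
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        # Fully Connected
        hidden_states = self.feed_forward(hidden_states, forward_batch)

        # postprocess_layer replaces the manual dp_scatter block removed above
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual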
