4 changes: 2 additions & 2 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -283,7 +283,7 @@ def _load_w13(
)

         expert_data = expert_data.narrow(shard_dim, start, shard_size)
-        expert_data.copy_(loaded_weight)
+        expert_data.copy_(loaded_weight, non_blocking=True)
Contributor


critical

Using non_blocking=True is a good optimization for overlapping host-to-device transfers. However, it makes the copy asynchronous: the call returns before the data has landed, so explicit synchronization is required before the weights can be safely used.

torch.cuda.synchronize() has been added to qwen3_moe.py, which is correct. But since this change is in a shared layer file, it affects every model that uses FusedMoE. For example, qwen2_moe.py also uses this layer, yet its load_weights method has no synchronization call.

Without that barrier, weights can be read before they are fully copied to the device — a race condition. Please ensure that all models using FusedMoE are updated to synchronize after weight loading.
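A minimal, hedged sketch of the pattern under review (not the sglang code itself): copies are enqueued with non_blocking=True, and a single device barrier is issued once loading finishes. On a CPU-only host the non-blocking flag is effectively ignored and the guard skips the sync, so the sketch runs anywhere; true async transfer additionally needs a pinned-memory source.

```python
import torch

def load_expert_weight(expert_data: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Enqueue the copy without blocking the host thread. For a real
    # async H2D transfer the source should live in pinned memory.
    expert_data.copy_(loaded_weight, non_blocking=True)

def finish_loading() -> None:
    # Called once after all copies are enqueued; the barrier ensures
    # every pending transfer has landed before the weights are read.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

dst = torch.zeros(4)
src = torch.arange(4, dtype=torch.float32)
load_expert_weight(dst, src)
finish_loading()
print(torch.equal(dst, src))
```

The key point is that the barrier is placed once, after the whole loading loop, rather than after each copy — otherwise the overlap the optimization buys is lost.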


def _load_w2(
self,
@@ -347,7 +347,7 @@ def _load_w2(
)

         # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)
+        expert_data.copy_(loaded_weight, non_blocking=True)


def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
3 changes: 3 additions & 0 deletions python/sglang/srt/models/qwen3_moe.py
@@ -836,6 +836,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 logger.warning(f"Parameter {name} not found in params_dict")

+        # Synchronize to ensure all weights are loaded since we loaded them in non-blocking mode
+        torch.cuda.synchronize()
Contributor


high

This synchronization call is necessary for the non-blocking copies to complete before the weights are used. However, it will raise an error when run in a CPU-only environment where CUDA is not available.

To prevent this, guard the call with a CUDA-availability check. The _is_cuda variable is already defined in this file for exactly this purpose.

Suggested change
-        torch.cuda.synchronize()
+        if _is_cuda:
+            torch.cuda.synchronize()
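As a self-contained sketch of the suggested guard (the module-level flag here is illustrative; sglang's _is_cuda may be computed differently): torch.cuda.synchronize() raises on CUDA-less builds, so the helper becomes a no-op there and a device barrier otherwise.

```python
import torch

# Illustrative stand-in for the _is_cuda flag mentioned in the review.
_is_cuda = torch.cuda.is_available()

def sync_after_weight_load() -> None:
    # No-op on CPU-only hosts, where torch.cuda.synchronize() would
    # raise; a full device barrier when CUDA is present.
    if _is_cuda:
        torch.cuda.synchronize()

sync_after_weight_load()  # safe on both CPU-only and GPU hosts
```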

Collaborator


We need to add the synchronize call to all MoE models, e.g. DeepSeekV3 and GLM4MoE.
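One hedged way to avoid duplicating the barrier across DeepSeekV3, GLM4MoE, Qwen2/3-MoE, and the rest is a shared helper that each model's load_weights calls as its final step. All names below are illustrative, not sglang's actual API:

```python
import torch

def finalize_weight_loading() -> None:
    # One barrier covers every non_blocking copy issued while loading.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

class ToyMoEModel:
    """Illustrative model; params_dict stands in for named parameters."""

    def __init__(self):
        self.params_dict = {"w2": torch.zeros(2)}

    def load_weights(self, weights):
        for name, loaded_weight in weights:
            # Mirrors the shared-layer change under review: the copy
            # is enqueued without waiting for it to finish.
            self.params_dict[name].copy_(loaded_weight, non_blocking=True)
        # Required final step for every model that uses FusedMoE.
        finalize_weight_loading()

model = ToyMoEModel()
model.load_weights([("w2", torch.ones(2))])
print(torch.equal(model.params_dict["w2"], torch.ones(2)))
```

Centralizing the barrier this way means a newly added MoE model cannot forget it, which addresses the race the earlier comment describes.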


         # TODO mimic deepseek
         self.routed_experts_weights_of_layer = {
             layer_id: self.model.layers[layer_id].mlp.get_moe_weights()