[3/3] Optimize Slime Update Weights: Load Weight in non-blocking mode #8754
```diff
@@ -283,7 +283,7 @@ def _load_w13(
         )

         expert_data = expert_data.narrow(shard_dim, start, shard_size)
-        expert_data.copy_(loaded_weight)
+        expert_data.copy_(loaded_weight, non_blocking=True)

     def _load_w2(
         self,
```
```diff
@@ -347,7 +347,7 @@ def _load_w2(
         )

         # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)
+        expert_data.copy_(loaded_weight, non_blocking=True)
```
Contributor

Using `non_blocking=True` is a good optimization for overlapping data transfers. However, this makes the copy asynchronous and requires explicit synchronization before the weights can be safely used.

I see that `torch.cuda.synchronize()` has been added to `qwen3_moe.py`, which is correct. But since this change is in a shared layer file, it will affect all models that use `FusedMoE`. For example, `qwen2_moe.py` also seems to use this layer, but its `load_weights` method does not have a synchronization call.

This could lead to race conditions where weights are used before they are fully copied to the device. Please ensure that all models using `FusedMoE` are updated to include proper synchronization after weight loading to prevent this.
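The race the reviewer describes can be illustrated without a GPU. The sketch below is a pure-Python analogy, not the torch API: a background thread stands in for the CUDA copy engine, and `AsyncBuffer`, `copy_nonblocking`, and `synchronize` are invented names for illustration.

```python
import threading
import time

# Pure-Python analogy for a non-blocking device copy: the buffer is filled
# by a background thread, mirroring how copy_(..., non_blocking=True)
# returns before the transfer has actually finished.
class AsyncBuffer:
    def __init__(self, size):
        self.data = [0] * size
        self._thread = None

    def copy_nonblocking(self, src):
        def worker():
            time.sleep(0.05)  # simulate transfer latency
            self.data = list(src)
        self._thread = threading.Thread(target=worker)
        self._thread.start()  # returns immediately; copy is still in flight

    def synchronize(self):
        # Analogous to torch.cuda.synchronize(): block until the copy lands.
        if self._thread is not None:
            self._thread.join()

buf = AsyncBuffer(3)
buf.copy_nonblocking([1, 2, 3])
# Reading buf.data right here could still observe the stale zeros -- the race.
buf.synchronize()
print(buf.data)  # after synchronize(), the copy is guaranteed complete
```

In the PR, `expert_data.copy_(loaded_weight, non_blocking=True)` plays the role of `copy_nonblocking`, and `torch.cuda.synchronize()` at the end of `load_weights` plays the role of `synchronize`.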
```diff
     def _load_single_value(
         self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
```
```diff
@@ -836,6 +836,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 logger.warning(f"Parameter {name} not found in params_dict")

+        # Synchronize to ensure all weights are loaded since we loaded them in non-blocking mode
+        torch.cuda.synchronize()
+
```
Contributor

This synchronization call is necessary for the non-blocking copies to work correctly. However, it will raise an error if the code is run in a CPU-only environment where CUDA is not available. To prevent this, you should guard this call with a check for CUDA availability.

Suggested change
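A sketch of the guard the reviewer describes, assuming the call sits at the end of `load_weights` with the indentation shown in the hunk above:

```diff
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
```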
Collaborator

We need to add the synchronize to all MoE models, e.g. DeepSeekV3, GLM4MoE.
```diff
         # TODO mimic deepseek
         self.routed_experts_weights_of_layer = {
             layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
```