Merged
6 changes: 4 additions & 2 deletions python/sglang/srt/models/qwen3.py
@@ -674,8 +674,10 @@ def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight

     def set_embed_and_head(self, embed, head):
-        del self.model.embed_tokens.weight
-        del self.lm_head.weight
+        if hasattr(self.model.embed_tokens, "weight"):
+            del self.model.embed_tokens.weight
+        if hasattr(self.lm_head, "weight"):
+            del self.lm_head.weight
         self.model.embed_tokens.weight = embed
         self.lm_head.weight = head
Comment on lines +677 to 682
medium

When pipeline parallelism is enabled (`pp_size > 1`), ranks other than the first rank will have `self.model.embed_tokens` as a `PPMissingLayer`, and ranks other than the last rank will have `self.lm_head` as a `PPMissingLayer`.

Additionally, when `tie_word_embeddings` is enabled, `self.lm_head` is the exact same object as `self.model.embed_tokens`.

The current implementation:

  1. Unnecessarily assigns `weight` to `PPMissingLayer` placeholders on ranks where they are not used, which can prevent the large `embed` or `head` tensors from being garbage collected on those ranks.
  2. Performs redundant deletion and setting operations when weights are tied.

We can make this method fully robust and memory-efficient by checking whether the modules are instances of `PPMissingLayer` and by skipping the second delete-and-assign when `self.lm_head is self.model.embed_tokens`.

Suggested change
-        if hasattr(self.model.embed_tokens, "weight"):
-            del self.model.embed_tokens.weight
-        if hasattr(self.lm_head, "weight"):
-            del self.lm_head.weight
-        self.model.embed_tokens.weight = embed
-        self.lm_head.weight = head
+        if not isinstance(self.model.embed_tokens, PPMissingLayer):
+            if hasattr(self.model.embed_tokens, "weight"):
+                del self.model.embed_tokens.weight
+            self.model.embed_tokens.weight = embed
+        if not isinstance(self.lm_head, PPMissingLayer) and self.lm_head is not self.model.embed_tokens:
+            if hasattr(self.lm_head, "weight"):
+                del self.lm_head.weight
+            self.lm_head.weight = head
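
For illustration, the suggested logic can be exercised without a GPU or the real model classes. The sketch below is not the sglang implementation: `PPMissingLayer`, `Layer`, `Inner`, and `ToyModel` are simplified stand-ins invented for this example, and plain strings take the place of weight tensors. It only shows how the three cases discussed above (first pipeline rank, last pipeline rank, tied embeddings) would behave under the suggested guards.

```python
class PPMissingLayer:
    """Hypothetical stand-in for the placeholder on ranks that do not own a layer."""


class Layer:
    """Hypothetical stand-in for an embedding or LM head module."""

    def __init__(self, weight=None):
        # Leave the attribute unset when no weight is given, so that the
        # hasattr() guard in set_embed_and_head is actually exercised.
        if weight is not None:
            self.weight = weight


class Inner:
    """Hypothetical stand-in for the inner model that owns embed_tokens."""

    def __init__(self, embed_tokens):
        self.embed_tokens = embed_tokens


class ToyModel:
    def __init__(self, embed_tokens, lm_head):
        self.model = Inner(embed_tokens)
        self.lm_head = lm_head

    def set_embed_and_head(self, embed, head):
        embed_tokens = self.model.embed_tokens
        # Skip placeholders so they never hold a reference to a large tensor.
        if not isinstance(embed_tokens, PPMissingLayer):
            if hasattr(embed_tokens, "weight"):
                del embed_tokens.weight
            embed_tokens.weight = embed
        # With tied embeddings lm_head is the same object as embed_tokens,
        # so it has already been updated above and `head` is not assigned.
        if (
            not isinstance(self.lm_head, PPMissingLayer)
            and self.lm_head is not embed_tokens
        ):
            if hasattr(self.lm_head, "weight"):
                del self.lm_head.weight
            self.lm_head.weight = head


# First pipeline rank: owns the embedding, the head is a placeholder.
first_rank = ToyModel(Layer("old_embed"), PPMissingLayer())
first_rank.set_embed_and_head("E", "H")

# Tied embeddings: both names refer to one module.
shared = Layer("old_shared")
tied = ToyModel(shared, shared)
tied.set_embed_and_head("E", "H")
```

In this sketch the placeholder on `first_rank` never receives a `weight` attribute, and in the tied case the shared module ends up holding `embed` while `head` is ignored. Whether ignoring `head` is the desired behavior when the weights are tied is a design choice for the PR authors to confirm.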

         torch.cuda.empty_cache()