Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def __init__(

# Override weight_loader for packed checkpoint format.
# Must capture original_loader BEFORE overwriting.
self.in_proj_qkvz.weight.weight_loader = self._make_packed_weight_loader(
self.in_proj_qkvz
self._override_weight_loader(
self.in_proj_qkvz, self._make_packed_weight_loader(self.in_proj_qkvz)
)
self.in_proj_ba.weight.weight_loader = self._make_packed_weight_loader(
self.in_proj_ba
self._override_weight_loader(
self.in_proj_ba, self._make_packed_weight_loader(self.in_proj_ba)
)

# Conv1d weight loader setup
Expand Down Expand Up @@ -216,6 +216,19 @@ def __init__(
dt_bias=self.dt_bias,
)

@staticmethod
def _override_weight_loader(module, new_loader):
"""Override weight_loader on a module's weight parameter.

ModelWeightParameter exposes weight_loader as a read-only property
backed by _weight_loader, while plain parameters store it as a
regular attribute. This helper handles both cases."""
param = module.weight
if hasattr(param, "_weight_loader"):
param._weight_loader = new_loader
else:
param.weight_loader = new_loader

@staticmethod
def _make_packed_weight_loader(module):
"""Create a weight_loader that does contiguous TP slicing for fused
Expand Down
Loading