Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def __init__(
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)

Expand Down Expand Up @@ -921,6 +922,15 @@ class HybridLayerType(enum.Enum):
class Qwen3NextForCausalLM(nn.Module):
fall_back_to_pt_during_load = False

# Map fused module names to their checkpoint (unfused) counterparts.
# This is needed so the quantization exclusion logic can match
# checkpoint-style names (e.g. "q_proj") against the fused sglang
# module names (e.g. "qkv_proj").
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to do this outside of the model?
You do it in Qwen3-Next and in Qwen3 (#18189). What about other models? Is this framework-specific or model-specific?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because the qkv and o projections are NVFP4 in these two recipes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't get your point. Can we do this outside of the model? I believe it should be in the quantization part: python/sglang/srt/layers/quantization/

def __init__(
self,
config: Qwen3NextConfig,
Expand All @@ -931,6 +941,14 @@ def __init__(
self.config = config
self.pp_group = get_pp_group()
assert self.pp_group.is_first_rank and self.pp_group.is_last_rank

# The quant config's packed_modules_mapping may be None if it wasn't
# in the checkpoint config. The base class (QuantizationConfig) intends
# for models to set this. We need it so is_layer_skipped can unfuse
# "qkv_proj" into ["q_proj","k_proj","v_proj"] when checking exclusions.
if quant_config is not None and hasattr(quant_config, "packed_modules_mapping"):
quant_config.packed_modules_mapping = self.packed_modules_mapping

self.quant_config = quant_config
self.model = Qwen3NextModel(
config, quant_config, prefix=add_prefix("model", prefix)
Expand Down Expand Up @@ -1052,6 +1070,14 @@ def load_weights(
if ".self_attn." in name:
name = name.replace(".self_attn", "")

# Remap modelopt FP8 KV cache scale names:
# checkpoint: k_proj.k_scale / v_proj.v_scale
# model: attn.k_scale / attn.v_scale
if name.endswith(".k_proj.k_scale"):
name = name.replace(".k_proj.k_scale", ".attn.k_scale")
elif name.endswith(".v_proj.v_scale"):
name = name.replace(".v_proj.v_scale", ".attn.v_scale")

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
Expand All @@ -1060,15 +1086,16 @@ def load_weights(
if "mlp.experts" in name:
continue

name = name.replace(weight_name, param_name)
replaced_name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if replaced_name.endswith(".bias") and replaced_name not in params_dict:
continue
# Skip layers on other devices.
# if is_pp_missing_parameter(name, self):
# continue
if name not in params_dict:
if replaced_name not in params_dict:
continue
name = replaced_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader")
weight_loader(param, loaded_weight, shard_id)
Expand All @@ -1078,15 +1105,17 @@ def load_weights(
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
replaced_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
# if is_pp_missing_parameter(name, self):
# continue
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
replaced_name.endswith(".bias")
or replaced_name.endswith("_bias")
) and replaced_name not in params_dict:
continue
name = replaced_name
param = params_dict[name]

weight_loader = getattr(param, "weight_loader")
Expand Down
Loading