diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index eed21a9bd6ab..880201a7d8a9 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -665,6 +665,7 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=f"{prefix}.attn", ) @@ -921,6 +922,15 @@ class HybridLayerType(enum.Enum): class Qwen3NextForCausalLM(nn.Module): fall_back_to_pt_during_load = False + # Map fused module names to their checkpoint (unfused) counterparts. + # This is needed so the quantization exclusion logic can match + # checkpoint-style names (e.g. "q_proj") against the fused sglang + # module names (e.g. "qkv_proj"). + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + def __init__( self, config: Qwen3NextConfig, @@ -931,6 +941,14 @@ def __init__( self.config = config self.pp_group = get_pp_group() assert self.pp_group.is_first_rank and self.pp_group.is_last_rank + + # The quant config's packed_modules_mapping may be None if it wasn't + # in the checkpoint config. The base class (QuantizationConfig) intends + # for models to set this. We need it so is_layer_skipped can unfuse + # "qkv_proj" into ["q_proj","k_proj","v_proj"] when checking exclusions. + if quant_config is not None and hasattr(quant_config, "packed_modules_mapping"): + quant_config.packed_modules_mapping = self.packed_modules_mapping + self.quant_config = quant_config self.model = Qwen3NextModel( config, quant_config, prefix=add_prefix("model", prefix) @@ -1052,6 +1070,14 @@ def load_weights( if ".self_attn." in name: name = name.replace(".self_attn", "") + # Remap modelopt FP8 KV cache scale names: + # checkpoint: k_proj.k_scale / v_proj.v_scale + # model: attn.k_scale / attn.v_scale + if name.endswith(".k_proj.k_scale"): + name = name.replace(".k_proj.k_scale", ".attn.k_scale") + elif name.endswith(".v_proj.v_scale"): + name = name.replace(".v_proj.v_scale", ".attn.v_scale") + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -1060,15 +1086,16 @@ def load_weights( if "mlp.experts" in name: continue - name = name.replace(weight_name, param_name) + replaced_name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if replaced_name.endswith(".bias") and replaced_name not in params_dict: continue # Skip layers on other devices. # if is_pp_missing_parameter(name, self): # continue - if name not in params_dict: + if replaced_name not in params_dict: continue + name = replaced_name param = params_dict[name] weight_loader = getattr(param, "weight_loader") weight_loader(param, loaded_weight, shard_id) @@ -1078,15 +1105,17 @@ def load_weights( param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) + replaced_name = name.replace(weight_name, param_name) # Skip layers on other devices. # if is_pp_missing_parameter(name, self): # continue # Skip loading extra bias for GPTQ models. if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: + replaced_name.endswith(".bias") + or replaced_name.endswith("_bias") + ) and replaced_name not in params_dict: continue + name = replaced_name param = params_dict[name] weight_loader = getattr(param, "weight_loader")