python/sglang/srt/layers/rotary_embedding.py (4 changes: 3 additions & 1 deletion)

@@ -3411,7 +3411,9 @@ def get_rope(
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
else:
raise ValueError("Unknown RoPE scaling type")
raise ValueError(
f"Unknown RoPE scaling type, rope_scaling is {rope_scaling}"
)

if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
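For context, this is the config-validation path the new message improves: when a checkpoint's `rope_scaling` dict carries no recognized type key, the error now echoes the dict itself. A minimal runnable sketch, assuming the branch ahead of the visible `elif` checks the HF-style `"rope_type"` key (the dict below is hypothetical):

```python
# Hypothetical checkpoint config missing both "rope_type" and "type".
rope_scaling = {"factor": 8.0}

if "rope_type" in rope_scaling:
    scaling_type = rope_scaling["rope_type"]
elif "type" in rope_scaling:
    scaling_type = rope_scaling["type"]
else:
    # Echoing the dict pinpoints the offending config in the traceback.
    raise ValueError(f"Unknown RoPE scaling type, rope_scaling is {rope_scaling}")
```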

python/sglang/srt/models/qwen3_5.py (26 changes: 19 additions & 7 deletions)

@@ -318,8 +318,14 @@ def __init__(
         super().__init__()
         self.config = config
         self.layer_id = layer_id
+
+        linear_attn_quant_config = (
+            None
+            if quant_config and quant_config.get_name() == "modelopt_fp4"
+            else quant_config
+        )
         self.linear_attn = Qwen3_5GatedDeltaNet(
-            config, layer_id, quant_config, alt_stream, prefix
+            config, layer_id, linear_attn_quant_config, alt_stream, prefix
         )
 
         # NOTE: Determine the MLP type based on the model type
@@ -458,13 +464,19 @@ def __init__(
             dtype=torch.get_default_dtype(),
         )
 
+        attn_quant_config = (
+            None
+            if quant_config and quant_config.get_name() == "modelopt_fp4"
+            else quant_config
+        )
+
         self.qkv_proj = QKVParallelLinear(
             config.hidden_size,
             self.head_dim,
             self.total_num_heads * (1 + self.attn_output_gate),
             self.total_num_kv_heads,
             bias=False,
-            quant_config=quant_config,
+            quant_config=attn_quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
@@ -474,7 +486,7 @@
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
-            quant_config=quant_config,
+            quant_config=attn_quant_config,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
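The same gating expression appears in both the gated-delta-net and the attention-projection constructors: when the checkpoint is ModelOpt NVFP4, these layers stay unquantized, so they must be built without a quant config. A minimal sketch of the shared pattern; the helper name is hypothetical, as the PR inlines the expression rather than factoring it out:

```python
from typing import Any, Optional

def maybe_drop_fp4_quant(quant_config: Optional[Any]) -> Optional[Any]:
    # Layers that stay unquantized in the nvfp4 checkpoint must be
    # constructed with quant_config=None so they fall back to plain weights.
    if quant_config and quant_config.get_name() == "modelopt_fp4":
        return None
    return quant_config
```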
@@ -1155,9 +1167,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             "_k_scale",
             ".v_scale",
             "_v_scale",
-            ".weight_scale",
             "_weight_scale",
-            ".input_scale",
             "_input_scale",
         )
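With the dot-prefixed scale suffixes gone from the ignore list, NVFP4 weight and input scales are loaded instead of silently skipped. A simplified sketch of the loader's skip check (suffix tuple abridged to the lines visible in this hunk; the parameter name is illustrative):

```python
ignore_suffixes = ("_k_scale", ".v_scale", "_v_scale", "_weight_scale", "_input_scale")

def should_skip(name: str, params_dict: dict) -> bool:
    # Skip only unknown extra parameters whose suffix is still ignored.
    return name.endswith(ignore_suffixes) and name not in params_dict

# ".weight_scale" is no longer an ignored suffix, so this FP4 scale loads.
print(should_skip("model.layers.0.self_attn.qkv_proj.weight_scale", {}))  # False
```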

@@ -1204,7 +1214,9 @@ def load_fused_expert_weights(
                 name = name.replace(".self_attn", "")
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                if name.endswith("experts.gate_up_proj") or name.endswith(
+                    "experts.down_proj"
+                ):
                     is_fused_expert = True
                     expert_params_mapping = fused_expert_params_mapping
 
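The old substring check also fired for auxiliary tensors that merely contain the projection name, such as per-tensor scales; `endswith` routes only the projection weight itself through the fused-expert path. A quick illustration (the parameter name is hypothetical):

```python
name = "model.layers.0.mlp.experts.gate_up_proj.weight_scale"

# Old check: matches the scale tensor as well as the weight.
assert "experts.gate_up_proj" in name
# New check: the scale tensor no longer takes the fused-expert path.
assert not name.endswith("experts.gate_up_proj")
```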
@@ -1274,7 +1286,7 @@ def load_fused_expert_weights(
                             num_experts,
                         )
                     else:
-                        # Skip loading extra parameters for GPTQ/modelopt models.
+                        # Skip loading extra parameters for GPTQ models.
                         if (
                             name_mapped.endswith(ignore_suffixes)
                             and name_mapped not in params_dict

python/sglang/srt/models/qwen3_5_mtp.py (4 changes: 4 additions & 0 deletions)

@@ -48,6 +48,10 @@ def __init__(
         if self.is_multimodal:
             config = config.text_config
 
+        # The MTP model is unquantized in the nvfp4 checkpoint.
+        if quant_config and quant_config.get_name() == "modelopt_fp4":
+            quant_config = None
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
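A minimal sketch of the effect, using a stub in place of the real ModelOpt config object:

```python
class FakeModelOptFp4Config:
    # Stand-in for the real quant config; only get_name() matters here.
    def get_name(self) -> str:
        return "modelopt_fp4"

quant_config = FakeModelOptFp4Config()
if quant_config and quant_config.get_name() == "modelopt_fp4":
    quant_config = None  # MTP weights ship unquantized in the nvfp4 checkpoint

# self.quant_config then stores None, so every MTP layer is built unquantized.
assert quant_config is None
```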