Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions python/sglang/srt/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,10 @@ def _modelopt_override_quantization_method(
# Check if this is a ModelOpt config
quant_algo = hf_quant_config.get("quant_algo", "").upper()

# If user specified generic "modelopt", auto-detect the specific method
if user_quant == "modelopt":
if "FP8" in quant_algo:
return "modelopt_fp8"
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
return "modelopt_fp4"
if "FP8" in quant_algo:
return "modelopt_fp8"
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
return "modelopt_fp4"

# The hf_quant_config may be a parsed quant config, so we need to check the
# quant_method.
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ def _slice_scale(w):
# For other backends, ensure the per-input block dimension is aligned to 16.
assert (
weight_scale.shape[assert_dim] % block_size == 0
), f"Expected {name}_weight_scale.dim({assert_dim}) to be divisible by {block_size}"
), f"Expected {name}_weight_scale.dim({assert_dim}) to be divisible by {block_size}, got {weight_scale.shape[assert_dim]}"
assert (
weight_scale.dtype == torch.float8_e4m3fn
), f"{name} Weight Blockscale must be represented as FP8-E4M3"
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down Expand Up @@ -375,6 +376,7 @@ def __init__(
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
routing_method_type=RoutingMethodType.DeepSeekV3,
prefix=add_prefix("experts", prefix),
)

Expand Down
22 changes: 16 additions & 6 deletions python/sglang/srt/models/glm4_moe_nextn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
Expand All @@ -34,11 +35,14 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, is_cuda

logger = logging.getLogger(__name__)


_is_cuda = is_cuda()


class Glm4MoeModelNextN(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -67,12 +71,19 @@ def __init__(

self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

self.alt_stream = (
torch.cuda.Stream()
if _is_cuda or envs.SGLANG_NPU_USE_MULTI_STREAM.get()
else None
)

self.decoder = Glm4MoeDecoderLayer(
config,
0,
quant_config=quant_config,
is_nextn=True,
prefix=add_prefix("decoder", prefix),
alt_stream=self.alt_stream,
)

self.shared_head = nn.Module()
Expand Down Expand Up @@ -127,6 +138,9 @@ def __init__(
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.pp_group = get_pp_group()
self.num_fused_shared_experts = 0
self.determine_num_fused_shared_experts()
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
Expand All @@ -139,10 +153,6 @@ def __init__(
)
self.logits_processor = LogitsProcessor(config)

self.num_fused_shared_experts = (
0 if get_global_server_args().disable_shared_experts_fusion else 1
)

@torch.no_grad()
def forward(
self,
Expand Down
Loading