diff --git a/docs/advanced_features/dp_for_multi_modal_encoder.md b/docs/advanced_features/dp_for_multi_modal_encoder.md index 08d907b9fb78..62057f9581a0 100644 --- a/docs/advanced_features/dp_for_multi_modal_encoder.md +++ b/docs/advanced_features/dp_for_multi_modal_encoder.md @@ -27,3 +27,4 @@ python3 -m sglang.launch_server \ - Qwen2.5-VL () - Qwen3-VL () - InternVL () +- GLM-4.5V & GLM-4.6V () diff --git a/docs/basic_usage/glm45.md b/docs/basic_usage/glm45.md new file mode 100644 index 000000000000..d18b0a68d335 --- /dev/null +++ b/docs/basic_usage/glm45.md @@ -0,0 +1,70 @@ +## Launch GLM-4.5 / GLM-4.6 with SGLang + +To serve GLM-4.5 / GLM-4.6 FP8 models on 8xH100/H200 GPUs: + +```bash +python3 -m sglang.launch_server --model zai-org/GLM-4.6-FP8 --tp 8 +``` + +### Configuration Tips + +- `--max-mamba-cache-size`: Adjust `--max-mamba-cache-size` to increase mamba cache space and max running requests + capability. It will decrease KV cache space as a trade-off. You can adjust it according to workload. + +### EAGLE Speculative Decoding + +**Description**: SGLang has supported GLM-4.5 / GLM-4.6 models +with [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding). + +**Usage**: +Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk` and +`--speculative-num-draft-tokens` to enable this feature. For example: + +``` bash +python3 -m sglang.launch_server \ + --model-path zai-org/GLM-4.6-FP8 \ + --tp-size 8 \ + --tool-call-parser glm45 \ + --reasoning-parser glm45 \ + --speculative-algorithm EAGLE \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.9 \ + --served-model-name glm-4.6-fp8 \ + --enable-custom-logit-processor +``` + +### Thinking Budget for GLM-4.5 / GLM-4.6 + +In SGLang, we can implement thinking budget with `CustomLogitProcessor`. 
+ +Launch a server with `--enable-custom-logit-processor` flag on. + +Sample Request: + +```python +import openai +from rich.pretty import pprint +from sglang.srt.sampling.custom_logit_processor import Glm4MoeThinkingBudgetLogitProcessor + + +client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*") +response = client.chat.completions.create( + model="zai-org/GLM-4.6", + messages=[ + { + "role": "user", + "content": "Question: Is Paris the Capital of France?", + } + ], + max_tokens=1024, + extra_body={ + "custom_logit_processor": Glm4MoeThinkingBudgetLogitProcessor().to_str(), + "custom_params": { + "thinking_budget": 512, + }, + }, +) +pprint(response) +``` diff --git a/docs/basic_usage/glmv.md b/docs/basic_usage/glmv.md new file mode 100644 index 000000000000..c56b6ecd54cb --- /dev/null +++ b/docs/basic_usage/glmv.md @@ -0,0 +1,136 @@ +# GLM-4.6V / GLM-4.5V Usage + +## Launch commands for SGLang + +Below are suggested launch commands tailored for different hardware / precision modes + +### FP8 (quantised) mode + +For high memory-efficiency and latency optimized deployments (e.g., on H100, H200) where FP8 checkpoint is supported: + +```bash +python3 -m sglang.launch_server \ + --model-path zai-org/GLM-4.6V-FP8 \ + --tp 2 \ + --ep 2 \ + --host 0.0.0.0 \ + --port 30000 \ + --keep-mm-feature-on-device +``` + +### Non-FP8 (BF16 / full precision) mode +For deployments on A100/H100 where BF16 is used (or FP8 snapshot not used): +```bash +python3 -m sglang.launch_server \ + --model-path zai-org/GLM-4.6V \ + --tp 4 \ + --ep 4 \ + --host 0.0.0.0 \ + --port 30000 +``` + +## Hardware-specific notes / recommendations + +- On H100 with FP8: Use the FP8 checkpoint for best memory efficiency. +- On A100 / H100 with BF16 (non-FP8): It’s recommended to use `--mm-max-concurrent-calls` to control parallel throughput and GPU memory usage during image/video inference. 
+- On H200 & B200: The model can be run “out of the box”, supporting full context length plus concurrent image + video processing. + +## Sending Image/Video Requests + +### Image input: + +```python +import requests + +url = f"http://localhost:30000/v1/chat/completions" + +data = { + "model": "zai-org/GLM-4.6V", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true" + }, + }, + ], + } + ], + "max_tokens": 300, +} + +response = requests.post(url, json=data) +print(response.text) +``` + +### Video Input: + +```python +import requests + +url = f"http://localhost:30000/v1/chat/completions" + +data = { + "model": "zai-org/GLM-4.6V", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s happening in this video?"}, + { + "type": "video_url", + "video_url": { + "url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4" + }, + }, + ], + } + ], + "max_tokens": 300, +} + +response = requests.post(url, json=data) +print(response.text) +``` + +## Important Server Parameters and Flags + +When launching the model server for **multimodal support**, you can use the following command-line arguments to fine-tune performance and behavior: + +- `--mm-attention-backend`: Specify multimodal attention backend. Eg. `fa3`(Flash Attention 3) +- `--mm-max-concurrent-calls `: Specifies the **maximum number of concurrent asynchronous multimodal data processing calls** allowed on the server. Use this to control parallel throughput and GPU memory usage during image/video inference. +- `--mm-per-request-timeout `: Defines the **timeout duration (in seconds)** for each multimodal request. If a request exceeds this time limit (e.g., for very large video inputs), it will be automatically terminated. 
+- `--keep-mm-feature-on-device`: Instructs the server to **retain multimodal feature tensors on the GPU** after processing. This avoids device-to-host (D2H) memory copies and improves performance for repeated or high-frequency inference workloads.
+- `--mm-enable-dp-encoder`: Placing the ViT in data parallel while keeping the LLM in tensor parallel consistently lowers TTFT and boosts end-to-end throughput.
+- `SGLANG_USE_CUDA_IPC_TRANSPORT=1`: Shared memory pool based CUDA IPC for multi-modal data transport. For significantly improving e2e latency.
+
+### Example usage with the above optimizations:
+```bash
+SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
+SGLANG_VLM_CACHE_SIZE_MB=0 \
+python -m sglang.launch_server \
+    --model-path zai-org/GLM-4.6V \
+    --host 0.0.0.0 \
+    --port 30000 \
+    --trust-remote-code \
+    --tp-size 8 \
+    --enable-cache-report \
+    --log-level info \
+    --max-running-requests 64 \
+    --mem-fraction-static 0.65 \
+    --chunked-prefill-size 8192 \
+    --attention-backend fa3 \
+    --mm-attention-backend fa3 \
+    --mm-enable-dp-encoder \
+    --enable-metrics
+```
+
+### Thinking Budget for GLM-4.5V / GLM-4.6V
+
+In SGLang, we can implement thinking budget with `CustomLogitProcessor`.
+
+Launch a server with the `--enable-custom-logit-processor` flag on, and use `Glm4MoeThinkingBudgetLogitProcessor` in the request like the `GLM-4.6` example in [glm45.md](./glm45.md).
diff --git a/docs/basic_usage/popular_model_usage.rst b/docs/basic_usage/popular_model_usage.rst
index b8e75f2180be..db70118aed00 100644
--- a/docs/basic_usage/popular_model_usage.rst
+++ b/docs/basic_usage/popular_model_usage.rst
@@ -1,4 +1,4 @@
-Popular Model Usage (DeepSeek, GPT-OSS, Llama, Qwen, and more)
+Popular Model Usage (DeepSeek, GPT-OSS, GLM, Llama, Qwen, and more)
 ===============================================================
 
 ..
toctree:: @@ -6,6 +6,8 @@ Popular Model Usage (DeepSeek, GPT-OSS, Llama, Qwen, and more) deepseek_v3.md deepseek_v32.md + glm45.md + glmv.md gpt_oss.md qwen3.md qwen3_vl.md diff --git a/python/pyproject.toml b/python/pyproject.toml index 85428cba02d8..a0663d631794 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -65,7 +65,7 @@ dependencies = [ "torch_memory_saver==0.0.9", "torch==2.9.1", "torchaudio==2.9.1", - "torchcodec==0.7.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", # torchcodec does not exist in those systems. If not provided, transformer will use torchvision instead by default. + "torchcodec==0.8.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", # torchcodec does not exist in those systems. If not provided, transformer will use torchvision instead by default. "torchvision", "torchao==0.9.0", "tqdm", diff --git a/python/sglang/srt/configs/qwen3_omni.py b/python/sglang/srt/configs/qwen3_omni.py index d42e98a9a07b..8baea892335d 100644 --- a/python/sglang/srt/configs/qwen3_omni.py +++ b/python/sglang/srt/configs/qwen3_omni.py @@ -1,6 +1,5 @@ from transformers import PretrainedConfig from transformers.configuration_utils import layer_type_validation -from transformers.modeling_rope_utils import rope_config_validation from sglang.utils import logger @@ -168,7 +167,6 @@ def __init__( # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) # MoE arguments self.decoder_sparse_step = decoder_sparse_step @@ -311,7 +309,6 @@ def __init__( # BC: if there is a 'type' field, move it to 'rope_type'. 
if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) self.layer_types = layer_types if self.layer_types is None: @@ -405,7 +402,6 @@ def __init__( # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) # MoE arguments self.decoder_sparse_step = decoder_sparse_step diff --git a/python/sglang/srt/configs/qwen3_vl.py b/python/sglang/srt/configs/qwen3_vl.py index a758d1f4e45e..85068b5a6002 100644 --- a/python/sglang/srt/configs/qwen3_vl.py +++ b/python/sglang/srt/configs/qwen3_vl.py @@ -1,5 +1,4 @@ from transformers import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation class Qwen3VLVisionConfig(PretrainedConfig): @@ -187,8 +186,6 @@ def __init__( self.attention_bias = attention_bias self.attention_dropout = attention_dropout - rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) @@ -450,8 +447,6 @@ def __init__( self.rope_scaling = rope_scaling self.head_dim = head_dim or hidden_size // num_attention_heads - rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) - # MoE arguments self.decoder_sparse_step = decoder_sparse_step self.moe_intermediate_size = moe_intermediate_size diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index a9689b8f2754..280f5602c130 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -361,6 +361,7 @@ def __init__( if get_global_server_args().disable_shared_experts_fusion else config.n_shared_experts ) + self.config = config self.layer_id = layer_id self.alt_stream = alt_stream diff --git a/python/sglang/srt/models/glm4v.py 
b/python/sglang/srt/models/glm4v.py index ddce004026fc..ae6edb6319e7 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -123,6 +123,7 @@ def __init__( num_heads: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + attn_qkv_bias: bool = True, num_dummy_heads: int = 0, rms_norm_eps: float = 1e-5, use_data_parallel: bool = False, @@ -136,7 +137,8 @@ def __init__( num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, - proj_bias=True, + proj_bias=False, + qkv_bias=attn_qkv_bias, flatten_batch=True, quant_config=quant_config, prefix=add_prefix("attn", prefix), @@ -440,6 +442,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix(f"blocks.{layer_idx}", prefix), rms_norm_eps=vision_config.rms_norm_eps, + attn_qkv_bias=vision_config.attention_bias, use_data_parallel=use_data_parallel, ) for layer_idx in range(depth) @@ -623,14 +626,27 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: self.visual.dtype ) video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) + + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = ( + torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + ) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + assert pixel_values.dim() == 2, pixel_values.dim() assert video_grid_thw.dim() == 2, video_grid_thw.dim() if self.use_data_parallel: return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d" + self.visual, + pixel_values, + flattened_video_grid_thw.tolist(), + rope_type="rope_3d", ) else: - video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) + video_embeds = self.visual(pixel_values, grid_thw=flattened_video_grid_thw) return video_embeds def get_input_embeddings(self): diff --git 
a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 8ec27dc0e153..324de18b49b7 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -6,21 +6,28 @@ import torch.nn as nn from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import ( + get_moe_expert_parallel_world_size, + get_tensor_model_parallel_world_size, +) +from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.glm4_moe import Glm4MoeModel from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix, is_cuda +from sglang.srt.utils import add_prefix, get_device_sm, is_cuda, log_info_on_rank0 from sglang.srt.utils.hf_transformers_utils import get_processor _is_cuda = is_cuda() +_device_sm = get_device_sm() logger = logging.getLogger(__name__) @@ -36,15 +43,14 @@ def __init__( ) -> None: nn.Module.__init__(self) + self.pp_group = get_pp_group() self.config = config + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder vision_utils.update_vit_attn_dummy_heads_config(self.config) self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config - 
self.num_fused_shared_experts = ( - 0 - if get_global_server_args().disable_shared_experts_fusion - else config.n_shared_experts - ) + self.num_fused_shared_experts = 0 + self.determine_num_fused_shared_experts() self.model = Glm4MoeModel( config, @@ -55,15 +61,24 @@ def __init__( config.vision_config, quant_config=quant_config, prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, ) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, - ) + if self.pp_group.is_last_rank: + if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + else: + # ranks other than the last rank will have a placeholder layer + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling @@ -71,6 +86,36 @@ def __init__( # For EAGLE3 support self.capture_aux_hidden_states = False + def determine_num_fused_shared_experts(self): + if get_global_server_args().disable_shared_experts_fusion: + return + + disable_reason = None + if not getattr(self.config, "n_shared_experts", None): + disable_reason = "No shared experts are defined in the config." + elif not _is_cuda: + disable_reason = "Shared experts fusion currently requires CUDA devices." + elif _is_cuda and (_device_sm is not None) and (_device_sm < 80): + disable_reason = "Shared experts fusion requires SM80 or newer GPUs." 
+ elif get_moe_expert_parallel_world_size() > 1: + disable_reason = "Shared experts fusion is not supported together with expert parallelism yet." + elif get_moe_a2a_backend().is_deepep(): + disable_reason = "Shared experts fusion is not supported when Deepep MoE backend is enabled." + + if disable_reason is not None: + get_global_server_args().disable_shared_experts_fusion = True + log_info_on_rank0( + logger, + f"{disable_reason} Shared experts fusion optimization is disabled.", + ) + return + + self.num_fused_shared_experts = self.config.n_shared_experts + assert ( + self.num_fused_shared_experts == 1 + ), "Only 1 fused shared expert is supported for Glm4vMoeForConditionalGeneration" + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): if is_nextn: if hasattr(self.config, "num_nextn_predict_layers"): @@ -98,7 +143,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, ) if is_nextn: @@ -115,6 +160,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal for name, loaded_weight in weights: weight_names.append(name) + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + # Shared expert becomes expert ID = n_routed_experts + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + if not is_nextn: if hasattr(self.config, "num_nextn_predict_layers"): num_nextn_layers = self.config.num_nextn_predict_layers @@ -150,6 +202,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal name = name.replace("model.visual.", "visual.") if "rotary_emb.inv_freq" in name: continue + for param_name, weight_name, 
shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 9dfdff75cf1d..d58a5f6cf149 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -112,6 +112,14 @@ def __call__(self, logits, custom_param_list: list[dict[str, Any]]): return logits +class Glm4MoeThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): + """A logit processor that controls the length of thinking for GLM-4.5 / GLM-4.6 / GLM-4.5V / GLM-4.6V models.""" + + THINKING_START_TOKEN_ID: int = 151350 + THINKING_END_TOKEN_ID: int = 151351 + NEW_LINE_TOKEN_ID: int = 198 + + class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): """A logit processor that controls the length of thinking for Qwen3 models.""" diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 6fa0b2404ba0..f6b06ca87306 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2716,6 +2716,7 @@ def is_fa3_default_architecture(hf_config): "Qwen3ForCausalLM", "Qwen3MoeForCausalLM", "Glm4MoeForCausalLM", + "Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "Step3VLForConditionalGeneration", }