**Description**: SGLang supports GLM-4.5 / GLM-4.6 models
+Launch a server with the `--enable-custom-logit-processor` flag on.
+- `--mm-attention-backend`: Specifies the multimodal attention backend, e.g. `fa3` (Flash Attention 3)
+Launch a server with the `--enable-custom-logit-processor` flag on, and use `Glm4MoeThinkingBudgetLogitProcessor` in the request, like the `GLM-4.6` example in [glm45.md](./glm45.md).
toctree:: @@ -6,6 +6,8 @@ Popular Model Usage (DeepSeek, GPT-OSS, Llama, Qwen, and more) deepseek_v3.md deepseek_v32.md + glm45.md + glmv.md gpt_oss.md qwen3.md qwen3_vl.md diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 8ec27dc0e153..ce1a09e45b40 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -7,6 +7,7 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE @@ -36,7 +37,9 @@ def __init__( ) -> None: nn.Module.__init__(self) + self.pp_group = get_pp_group() self.config = config + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder vision_utils.update_vit_attn_dummy_heads_config(self.config) self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config @@ -55,6 +58,7 @@ def __init__( config.vision_config, quant_config=quant_config, prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, ) self.lm_head = ParallelLMHead( diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 9dfdff75cf1d..d58a5f6cf149 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -112,6 +112,14 @@ def __call__(self, logits, custom_param_list: list[dict[str, Any]]): return logits +class Glm4MoeThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): + """A logit processor that controls the length of thinking for GLM-4.5 / GLM-4.6 / GLM-4.5V / GLM-4.6V models.""" + + THINKING_START_TOKEN_ID: int = 151350 + THINKING_END_TOKEN_ID: int = 151351 + 
NEW_LINE_TOKEN_ID: int = 198 + + class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): """A logit processor that controls the length of thinking for Qwen3 models."""