Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_features/dp_for_multi_modal_encoder.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ python3 -m sglang.launch_server \
- Qwen2.5-VL (<https://github.com/sgl-project/sglang/pull/13126>)
- Qwen3-VL (<https://github.com/sgl-project/sglang/pull/13724>)
- InternVL (<https://github.com/sgl-project/sglang/pull/13925>)
- GLM-4.5V & GLM-4.6V (<https://github.com/sgl-project/sglang/pull/14097>)
70 changes: 70 additions & 0 deletions docs/basic_usage/glm45.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
## Launch GLM-4.5 / GLM-4.6 with SGLang

To serve GLM-4.5 / GLM-4.6 FP8 models on 8xH100/H200 GPUs:

```bash
python3 -m sglang.launch_server --model zai-org/GLM-4.6-FP8 --tp 8
```

### Configuration Tips

- `--max-mamba-cache-size`: Adjust `--max-mamba-cache-size` to increase mamba cache space and max running requests
capability. It will decrease KV cache space as a trade-off. You can adjust it according to workload.

### EAGLE Speculative Decoding

**Description**: SGLang has supported GLM-4.5 / GLM-4.6 models
with [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding).

**Usage**:
Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk` and
`--speculative-num-draft-tokens` to enable this feature. For example:

``` bash
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6-FP8 \
--tp-size 8 \
--tool-call-parser glm45 \
--reasoning-parser glm45 \
--speculative-algorithm EAGLE \
--speculative-num-steps 3 \
--speculative-eagle-topk 1 \
--speculative-num-draft-tokens 4 \
--mem-fraction-static 0.9 \
--served-model-name glm-4.6-fp8 \
--enable-custom-logit-processor
```

### Thinking Budget for GLM-4.5 / GLM-4.6

In SGLang, we can implement thinking budget with `CustomLogitProcessor`.

Launch a server with `--enable-custom-logit-processor` flag on.

Sample Request:

```python
import openai
from rich.pretty import pprint
from sglang.srt.sampling.custom_logit_processor import Glm4MoeThinkingBudgetLogitProcessor


client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")
response = client.chat.completions.create(
model="zai-org/GLM-4.6",
messages=[
{
"role": "user",
"content": "Question: Is Paris the Capital of France?",
}
],
max_tokens=1024,
extra_body={
"custom_logit_processor": Glm4MoeThinkingBudgetLogitProcessor().to_str(),
"custom_params": {
"thinking_budget": 512,
},
},
)
pprint(response)
```
136 changes: 136 additions & 0 deletions docs/basic_usage/glmv.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# GLM-4.6V / GLM-4.5V Usage

## Launch commands for SGLang

Below are suggested launch commands tailored for different hardware / precision modes

### FP8 (quantised) mode

For high memory-efficiency and latency optimized deployments (e.g., on H100, H200) where FP8 checkpoint is supported:

```bash
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6V-FP8 \
--tp 2 \
--ep 2 \
--host 0.0.0.0 \
--port 30000 \
--keep-mm-feature-on-device
```

### Non-FP8 (BF16 / full precision) mode
For deployments on A100/H100 where BF16 is used (or FP8 snapshot not used):
```bash
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6V \
--tp 4 \
--ep 4 \
--host 0.0.0.0 \
--port 30000
```

## Hardware-specific notes / recommendations

- On H100 with FP8: Use the FP8 checkpoint for best memory efficiency.
- On A100 / H100 with BF16 (non-FP8): It’s recommended to use `--mm-max-concurrent-calls` to control parallel throughput and GPU memory usage during image/video inference.
- On H200 & B200: The model can be run “out of the box”, supporting full context length plus concurrent image + video processing.

## Sending Image/Video Requests

### Image input:

```python
import requests

url = f"http://localhost:30000/v1/chat/completions"

data = {
"model": "zai-org/GLM-4.6V",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What’s in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true"
},
},
],
}
],
"max_tokens": 300,
}

response = requests.post(url, json=data)
print(response.text)
```

### Video Input:

```python
import requests

url = f"http://localhost:30000/v1/chat/completions"

data = {
"model": "zai-org/GLM-4.6V",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What’s happening in this video?"},
{
"type": "video_url",
"video_url": {
"url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4"
},
},
],
}
],
"max_tokens": 300,
}

response = requests.post(url, json=data)
print(response.text)
```

## Important Server Parameters and Flags

When launching the model server for **multimodal support**, you can use the following command-line arguments to fine-tune performance and behavior:

- `--mm-attention-backend`: Specify multimodal attention backend. Eg. `fa3`(Flash Attention 3)
- `--mm-max-concurrent-calls <value>`: Specifies the **maximum number of concurrent asynchronous multimodal data processing calls** allowed on the server. Use this to control parallel throughput and GPU memory usage during image/video inference.
- `--mm-per-request-timeout <seconds>`: Defines the **timeout duration (in seconds)** for each multimodal request. If a request exceeds this time limit (e.g., for very large video inputs), it will be automatically terminated.
- `--keep-mm-feature-on-device`: Instructs the server to **retain multimodal feature tensors on the GPU** after processing. This avoids device-to-host (D2H) memory copies and improves performance for repeated or high-frequency inference workloads.
- `--mm-enable-dp-encoder`: Placing the ViT in data parallel while keeping the LLM in tensor parallel consistently lowers TTFT and boosts end-to-end throughput.
- `SGLANG_USE_CUDA_IPC_TRANSPORT=1`: Shared memory pool based CUDA IPC for multi-modal data transport. For significantly improving e2e latency.

### Example usage with the above optimizations:
```bash
SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
SGLANG_VLM_CACHE_SIZE_MB=0 \
python -m sglang.launch_server \
--model-path zai-org/GLM-4.6V \
--host 0.0.0.0 \
--port 30000 \
--trust-remote-code \
--tp-size 8 \
--enable-cache-report \
--log-level info \
--max-running-requests 64 \
--mem-fraction-static 0.65 \
--chunked-prefill-size 8192 \
--attention-backend fa3 \
--mm-attention-backend fa3 \
--mm-enable-dp-encoder \
--enable-metrics
```

### Thinking Budget for GLM-4.5V / GLM-4.6V

In SGLang, we can implement thinking budget with `CustomLogitProcessor`.

Launch a server with `--enable-custom-logit-processor` flag on. and using `Glm4MoeThinkingBudgetLogitProcessor` in the request likes `GLM-4.6` example in [glm45.md](./glm45.md).
4 changes: 3 additions & 1 deletion docs/basic_usage/popular_model_usage.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
Popular Model Usage (DeepSeek, GPT-OSS, Llama, Qwen, and more)
Popular Model Usage (DeepSeek, GPT-OSS, GLM, Llama, Qwen, and more)
===============================================================

.. toctree::
:maxdepth: 1

deepseek_v3.md
deepseek_v32.md
glm45.md
glmv.md
gpt_oss.md
qwen3.md
qwen3_vl.md
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ dependencies = [
"torch_memory_saver==0.0.9",
"torch==2.9.1",
"torchaudio==2.9.1",
"torchcodec==0.7.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", # torchcodec does not exist in those systems. If not provided, transformer will use torchvision instead by default.
"torchcodec==0.8.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", # torchcodec does not exist in those systems. If not provided, transformer will use torchvision instead by default.
"torchvision",
"torchao==0.9.0",
"tqdm",
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/configs/qwen3_omni.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from transformers import PretrainedConfig
from transformers.configuration_utils import layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation

from sglang.utils import logger

Expand Down Expand Up @@ -168,7 +167,6 @@ def __init__(
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
Expand Down Expand Up @@ -311,7 +309,6 @@ def __init__(
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

self.layer_types = layer_types
if self.layer_types is None:
Expand Down Expand Up @@ -405,7 +402,6 @@ def __init__(
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/configs/qwen3_vl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class Qwen3VLVisionConfig(PretrainedConfig):
Expand Down Expand Up @@ -187,8 +186,6 @@ def __init__(
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


Expand Down Expand Up @@ -450,8 +447,6 @@ def __init__(
self.rope_scaling = rope_scaling
self.head_dim = head_dim or hidden_size // num_attention_heads

rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})

# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def __init__(
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)

self.config = config
self.layer_id = layer_id
self.alt_stream = alt_stream
Expand Down
22 changes: 19 additions & 3 deletions python/sglang/srt/models/glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
num_heads: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
attn_qkv_bias: bool = True,
num_dummy_heads: int = 0,
rms_norm_eps: float = 1e-5,
use_data_parallel: bool = False,
Expand All @@ -136,7 +137,8 @@ def __init__(
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
proj_bias=True,
proj_bias=False,
qkv_bias=attn_qkv_bias,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
Expand Down Expand Up @@ -440,6 +442,7 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
rms_norm_eps=vision_config.rms_norm_eps,
attn_qkv_bias=vision_config.attention_bias,
use_data_parallel=use_data_parallel,
)
for layer_idx in range(depth)
Expand Down Expand Up @@ -623,14 +626,27 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
self.visual.dtype
)
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)

# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = (
torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)

assert pixel_values.dim() == 2, pixel_values.dim()
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
self.visual,
pixel_values,
flattened_video_grid_thw.tolist(),
rope_type="rope_3d",
)
else:
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
video_embeds = self.visual(pixel_values, grid_thw=flattened_video_grid_thw)
return video_embeds

def get_input_embeddings(self):
Expand Down
Loading
Loading