Merged
python/sglang/srt/models/glm4.py (3 additions, 1 deletion)

@@ -220,7 +220,9 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
+        partial_rotary_factor = getattr(
+            getattr(config, "rope_parameters", None), "partial_rotary_factor", None
+        ) or getattr(config, "partial_rotary_factor", 0.5)
         dual_chunk_attention_config = getattr(
             config, "dual_chunk_attention_config", None
         )
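
Both GLM-4 files in this PR resolve partial_rotary_factor the same way: read it from a nested rope_parameters object first, then fall back to the flat config attribute, with 0.5 as the final default (glm4.py previously defaulted to None). Below is a minimal standalone sketch of that resolution order; the SimpleNamespace configs are hypothetical stand-ins for a Hugging Face config object, not sglang code:

    from types import SimpleNamespace

    def resolve_partial_rotary_factor(config) -> float:
        # getattr with an explicit default never raises, so this chain is
        # safe even when config has no rope_parameters attribute at all.
        return getattr(
            getattr(config, "rope_parameters", None), "partial_rotary_factor", None
        ) or getattr(config, "partial_rotary_factor", 0.5)

    # The nested value wins when present.
    nested = SimpleNamespace(
        rope_parameters=SimpleNamespace(partial_rotary_factor=0.25)
    )
    assert resolve_partial_rotary_factor(nested) == 0.25

    # Otherwise fall back to the flat attribute, then to the 0.5 default.
    flat = SimpleNamespace(partial_rotary_factor=1.0)
    assert resolve_partial_rotary_factor(flat) == 1.0
    assert resolve_partial_rotary_factor(SimpleNamespace()) == 0.5

One caveat of the "or" chaining: a falsy nested value such as 0.0 is treated as unset and falls through to the flat lookup.
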
python/sglang/srt/models/glm4_moe.py (3 additions, 1 deletion)

@@ -682,7 +682,9 @@ def __init__(
         self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        partial_rotary_factor = getattr(
+            getattr(config, "rope_parameters", None), "partial_rotary_factor", None
+        ) or getattr(config, "partial_rotary_factor", 0.5)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
python/sglang/srt/models/glm4v.py (84 additions, 50 deletions)

@@ -27,18 +27,24 @@
 from einops import rearrange
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.distributed.parallel_state import get_pp_group
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.layernorm import LayerNorm, RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
@@ -48,6 +54,8 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4 import Glm4Model
+from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor

@@ -73,21 +81,30 @@ def __init__(
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
             in_features,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         self.act_fn = SiluAndMul()
@@ -108,6 +125,7 @@ def __init__(
         prefix: str = "",
         num_dummy_heads: int = 0,
         rms_norm_eps: float = 1e-5,
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
@@ -123,12 +141,14 @@
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
             num_dummy_heads=num_dummy_heads,
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
             intermediate_dim,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
+            use_data_parallel=use_data_parallel,
         )

     def forward(
@@ -206,31 +226,37 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(
+        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
+        self.proj = ReplicatedLinear(
             self.hidden_size,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("proj", prefix),
-            gather_output=True,
         )
-        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
+        self.post_projection_norm = LayerNorm(self.hidden_size)
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_size=tp_size,
+            tp_rank=tp_rank,
         )
         self.down_proj = RowParallelLinear(
             context_dim,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            tp_size=tp_size,
+            tp_rank=tp_rank,
         )
         self.extra_activation_func = nn.GELU()
@@ -379,6 +405,7 @@ def __init__(
         vision_config: Glm4vVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -392,6 +419,7 @@
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
         self.out_hidden_size = vision_config.out_hidden_size
+        self.use_data_parallel = use_data_parallel

         self.patch_embed = Glm4vVisionPatchEmbed(
             patch_size=patch_size,
@@ -412,6 +440,7 @@
                     quant_config=quant_config,
                     prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                     rms_norm_eps=vision_config.rms_norm_eps,
+                    use_data_parallel=use_data_parallel,
                 )
                 for layer_idx in range(depth)
             ]
@@ -423,6 +452,7 @@
             quant_config=quant_config,
             bias=False,
             prefix=add_prefix("merger", prefix),
+            use_data_parallel=use_data_parallel,
         )

         self.embeddings = Glm4vVisionEmbeddings(vision_config)
@@ -527,11 +557,14 @@ def __init__(
     ) -> None:
         super().__init__()

+        self.pp_group = get_pp_group()
         self.config = config
+        self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
         self.visual = Glm4vVisionModel(
             config.vision_config,
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
+            use_data_parallel=self.use_data_parallel,
         )

         vision_utils.update_vit_attn_dummy_heads_config(self.config)
@@ -542,15 +575,19 @@ def __init__(
             prefix=add_prefix("model", prefix),
         )

-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and self.config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()

         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

@@ -565,45 +602,36 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.cat(
-            [item.feature.squeeze(0) for item in items], dim=0
-        ).type(self.visual.dtype)
+        # in GLM-V, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
-        # For multi-image, pixel_values is [num_of_images, L, C] shape
-        # assert pixel_values.dim() == 2, pixel_values.dim()
+        assert pixel_values.dim() == 2, pixel_values.dim()
         assert image_grid_thw.dim() == 2, image_grid_thw.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-        split_sizes = (
-            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
-        ).tolist()
-        image_embeds = torch.split(image_embeds, split_sizes)
-        return torch.cat(image_embeds)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(
+                self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d"
+            )
+        else:
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            return image_embeds

     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values_videos = torch.cat(
-            [item.feature.squeeze(0) for item in items], dim=0
-        ).type(self.visual.dtype)
+        # in GLM-V, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
-        # For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
-        # assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
+        assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()

-        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
-        temp_frames_hw = []
-        for t, h, w in video_grid_thw:
-            repeated_row = (
-                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
-            )
-            temp_frames_hw.append(repeated_row)
-        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=flattened_video_grid_thw
-        )
-        split_sizes = (
-            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
-        ).tolist()
-        video_embeds = torch.split(video_embeds, split_sizes)
-        return torch.cat(video_embeds)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(
+                self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
+            )
+        else:
+            video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+            return video_embeds
Review comment (Contributor), on lines 567 to 634:

medium: The logic in get_image_feature and get_video_feature is almost identical. This code duplication can be avoided by refactoring it into a single helper method. This would improve maintainability and reduce redundancy.
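
A minimal sketch of that refactor, mirroring the two branches shown above: both methods delegate to one shared helper parameterized by the grid attribute (the helper name _get_mm_feature is hypothetical, not part of this PR):

    def _get_mm_feature(
        self, items: List[MultimodalDataItem], grid_attr: str
    ) -> torch.Tensor:
        # Shared path: concatenate features, validate shapes, then run the
        # vision tower either DP-sharded or directly, as in the diff above.
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        grid_thw = torch.concat([getattr(item, grid_attr) for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert grid_thw.dim() == 2, grid_thw.dim()
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
            )
        return self.visual(pixel_values, grid_thw=grid_thw)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        return self._get_mm_feature(items, "image_grid_thw")

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        return self._get_mm_feature(items, "video_grid_thw")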


     def get_input_embeddings(self):
         return self.model.embed_tokens
@@ -653,12 +681,18 @@ def forward(
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states

-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states

     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
         """pad attn qkv weights for dummy heads"""
test/nightly/test_encoder_dp.py (1 addition, 0 deletions)

@@ -24,6 +24,7 @@
     SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55),
     SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55),
     SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52),
+    SimpleNamespace(model="zai-org/GLM-4.1V-9B-Thinking", mmmu_accuracy=0.68),
 ]

