diff --git a/python/sglang/srt/models/glm4.py b/python/sglang/srt/models/glm4.py
index 897a981f1b61..75c4ccc6989c 100644
--- a/python/sglang/srt/models/glm4.py
+++ b/python/sglang/srt/models/glm4.py
@@ -220,7 +220,9 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
+        partial_rotary_factor = getattr(
+            getattr(config, "rope_parameters", None), "partial_rotary_factor", None
+        ) or getattr(config, "partial_rotary_factor", 0.5)
         dual_chunk_attention_config = getattr(
             config, "dual_chunk_attention_config", None
         )
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index d160c8876330..a9689b8f2754 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -682,7 +682,9 @@ def __init__(
         self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        partial_rotary_factor = getattr(
+            getattr(config, "rope_parameters", None), "partial_rotary_factor", None
+        ) or getattr(config, "partial_rotary_factor", 0.5)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py
index b1ce0cc71b03..ddce004026fc 100644
--- a/python/sglang/srt/models/glm4v.py
+++ b/python/sglang/srt/models/glm4v.py
@@ -27,18 +27,24 @@
 from einops import rearrange
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.distributed.parallel_state import get_pp_group
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.layernorm import LayerNorm, RMSNorm
 from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
     MergedColumnParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
@@ -48,6 +54,8 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4 import Glm4Model
+from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor

@@ -73,14 +81,21 @@ def __init__(
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
@@ -88,6 +103,8 @@ def __init__(
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )

         self.act_fn = SiluAndMul()
@@ -108,6 +125,7 @@ def __init__(
         prefix: str = "",
         num_dummy_heads: int = 0,
         rms_norm_eps: float = 1e-5,
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
@@ -123,12 +141,14 @@ def __init__(
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
             num_dummy_heads=num_dummy_heads,
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
             intermediate_dim,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
+            use_data_parallel=use_data_parallel,
         )

     def forward(
@@ -206,24 +226,28 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(
+        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
+        self.proj = ReplicatedLinear(
             self.hidden_size,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("proj", prefix),
-            gather_output=True,
         )
-        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
+        self.post_projection_norm = LayerNorm(self.hidden_size)
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_size=tp_size,
+            tp_rank=tp_rank,
         )
         self.down_proj = RowParallelLinear(
             context_dim,
@@ -231,6 +255,8 @@ def __init__(
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            tp_size=tp_size,
+            tp_rank=tp_rank,
         )

         self.extra_activation_func = nn.GELU()
@@ -379,6 +405,7 @@ def __init__(
         vision_config: Glm4vVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -392,6 +419,7 @@ def __init__(
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
         self.out_hidden_size = vision_config.out_hidden_size
+        self.use_data_parallel = use_data_parallel

         self.patch_embed = Glm4vVisionPatchEmbed(
             patch_size=patch_size,
@@ -412,6 +440,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                     rms_norm_eps=vision_config.rms_norm_eps,
+                    use_data_parallel=use_data_parallel,
                 )
                 for layer_idx in range(depth)
             ]
@@ -423,6 +452,7 @@ def __init__(
             quant_config=quant_config,
             bias=False,
             prefix=add_prefix("merger", prefix),
+            use_data_parallel=use_data_parallel,
         )

         self.embeddings = Glm4vVisionEmbeddings(vision_config)
@@ -527,11 +557,14 @@ def __init__(
     ) -> None:
         super().__init__()

+        self.pp_group = get_pp_group()
         self.config = config
+        self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
         self.visual = Glm4vVisionModel(
             config.vision_config,
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
+            use_data_parallel=self.use_data_parallel,
         )
         vision_utils.update_vit_attn_dummy_heads_config(self.config)

@@ -542,15 +575,19 @@ def __init__(
             prefix=add_prefix("model", prefix),
         )

-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and self.config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # Ranks other than the last PP rank hold a placeholder layer.
+            self.lm_head = PPMissingLayer()

         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

@@ -565,45 +602,36 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.cat(
-            [item.feature.squeeze(0) for item in items], dim=0
-        ).type(self.visual.dtype)
+        # In GLM-V, per-item features share the same last dim, so concat directly.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
-        # For multi-image, pixel_values is [num_of_images, L, C] shape
-        # assert pixel_values.dim() == 2, pixel_values.dim()
+        assert pixel_values.dim() == 2, pixel_values.dim()
         assert image_grid_thw.dim() == 2, image_grid_thw.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-        split_sizes = (
-            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
-        ).tolist()
-        image_embeds = torch.split(image_embeds, split_sizes)
-        return torch.cat(image_embeds)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(
+                self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d"
+            )
+        else:
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            return image_embeds

     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values_videos = torch.cat(
-            [item.feature.squeeze(0) for item in items], dim=0
-        ).type(self.visual.dtype)
+        # In GLM-V, per-item features share the same last dim, so concat directly.
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
-        # For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
-        # assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
+        assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
-
-        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
-        temp_frames_hw = []
-        for t, h, w in video_grid_thw:
-            repeated_row = (
-                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(
+                self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
             )
-            temp_frames_hw.append(repeated_row)
-        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=flattened_video_grid_thw
-        )
-        split_sizes = (
-            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
-        ).tolist()
-        video_embeds = torch.split(video_embeds, split_sizes)
-        return torch.cat(video_embeds)
+        else:
+            video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+            return video_embeds

     def get_input_embeddings(self):
         return self.model.embed_tokens
@@ -653,12 +681,18 @@ def forward(
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states

-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states

     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
         """pad attn qkv weights for dummy heads"""
diff --git a/test/nightly/test_encoder_dp.py b/test/nightly/test_encoder_dp.py
index 5542204a16b9..a18075f71e7c 100644
--- a/test/nightly/test_encoder_dp.py
+++ b/test/nightly/test_encoder_dp.py
@@ -24,6 +24,7 @@
     SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55),
    SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55),
     SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52),
+    SimpleNamespace(model="zai-org/GLM-4.1V-9B-Thinking", mmmu_accuracy=0.68),
 ]
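Reviewer note: a minimal usage sketch for exercising the new path, assuming the offline sglang.Engine forwards keyword arguments to ServerArgs (the diff reads the flag as get_global_server_args().mm_enable_dp_encoder); the call shape below is illustrative, not part of this PR.

# Minimal sketch, assuming Engine kwargs map onto ServerArgs fields.
# With mm_enable_dp_encoder=True, the GLM-4V model sets use_data_parallel,
# so the ViT runs with tp_size=1 per rank and image/video batches are
# sharded across ranks via run_dp_sharded_mrope_vision_model.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="zai-org/GLM-4.1V-9B-Thinking",
        tp_size=4,
        mm_enable_dp_encoder=True,
    )
    out = llm.generate(
        prompt="Describe this image.",
        image_data=["https://example.com/cat.png"],  # placeholder URL
    )
    print(out)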