diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 49f9464eda6b..4b3856fd14b6 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1070,6 +1070,7 @@ def _triton_mrope_forward( mrope_section_h: tl.constexpr, mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, + is_neox_style: tl.constexpr, ): # Adapted from # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py @@ -1124,51 +1125,99 @@ def _triton_mrope_forward( # program instance (i.e. for the current token) separately # #################################################################### # left half of the head - first_half_q_offsets = ( - tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_half_k_offsets = ( - tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] - ) - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( - tl.arange(0, pad_hd // 2)[None, :] < rd // 2 - ) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( - tl.arange(0, pad_hd // 2)[None, :] < rd // 2 - ) + if is_neox_style: + first_half_q_offsets = ( + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_half_k_offsets = ( + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( - sin_row.dtype - ) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( + sin_row.dtype + ) - # right half of the head - second_half_q_offsets = first_half_q_offsets + (rd // 2) - second_half_k_offsets = first_half_k_offsets + (rd // 2) - second_q_mask = first_q_mask - second_k_mask = first_k_mask + # right half of the head + second_half_q_offsets = first_half_q_offsets + (rd // 2) + second_half_k_offsets = first_half_k_offsets + (rd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + + q_tile_2 = tl.load( + q_ptr + second_half_q_offsets, mask=second_q_mask, other=0 + ).to(sin_row.dtype) + k_tile_2 = tl.load( + k_ptr + second_half_k_offsets, mask=second_k_mask, other=0 + ).to(sin_row.dtype) + + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + # Since cos and sin are now half-size, + # we use the same cos_row and sin_row for both halves + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + base_q = tl.arange(0, pad_n_qh)[:, None] * hd + base_k = tl.arange(0, pad_n_kh)[:, None] * hd + even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :] + odd_idx = even_idx + 1 + + even_q_offsets 
= base_q + even_idx
+        odd_q_offsets = base_q + odd_idx
+        even_k_offsets = base_k + even_idx
+        odd_k_offsets = base_k + odd_idx
+
+        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
+        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
+        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
+
+        even_q_mask = qn_mask & idx_mask
+        odd_q_mask = qn_mask & idx_mask
+        even_k_mask = kn_mask & idx_mask
+        odd_k_mask = kn_mask & idx_mask
+
+        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+            sin_row.dtype
+        )
 
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+            sin_row.dtype
+        )
 
-    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-    # Since cos and sin are now half-size,
-    # we use the same cos_row and sin_row for both halves
-    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
+        # GPT-J-style (interleaved, non-NeoX) rotary embedding:
+        # Each (even, odd) channel pair forms one rotation arm.
+        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
 
-    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)
 
 
 def triton_mrope(
@@ -1180,6 +1229,7 @@ def triton_mrope(
     head_size: int,
     rotary_dim: int,
     mrope_interleaved: bool,
+    is_neox_style: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """The mrope triton kernel.
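
A note on the two code paths above: `is_neox_style` selects between the two standard rotary layouts. As a plain-PyTorch reference (a sketch for intuition only; the shapes and function names here are illustrative, not the kernel's actual tiling), assuming `x` of shape `(..., rotary_dim)` and half-size `cos`/`sin` of shape `(rotary_dim // 2,)`:

import torch

def rope_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # NeoX style: a rotation arm pairs channel i with channel i + rotary_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

def rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # GPT-J / interleaved style: a rotation arm pairs adjacent channels (2i, 2i + 1).
    # GLM4 takes this path: its get_rope call below uses is_neox_style=False.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out

Both variants apply the same rotation (x1 * cos - x2 * sin, x2 * cos + x1 * sin); only the channel pairing differs, which is why the kernel branches only on offset computation and reuses an identical multiply-add sequence.
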
@@ -1230,6 +1280,7 @@ def triton_mrope( mrope_section[1], mrope_section[2], mrope_interleaved, + is_neox_style, ) return q, k @@ -1400,6 +1451,7 @@ def _forward_triton( self.head_size, self.rotary_dim, self.mrope_interleaved, + self.is_neox_style, ) return q.reshape(query_shape), k.reshape(key_shape) diff --git a/python/sglang/srt/models/glm4.py b/python/sglang/srt/models/glm4.py index 7357e5f82565..897a981f1b61 100644 --- a/python/sglang/srt/models/glm4.py +++ b/python/sglang/srt/models/glm4.py @@ -15,46 +15,119 @@ # Modeling from: # ./llama.py and # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py -"""Inference-only GLM4 model compatible with THUDM weights.""" +"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +import logging +from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch from torch import nn -from transformers import Glm4Config -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.llama import LlamaMLP as Glm4MLP +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) from sglang.srt.utils import add_prefix, make_layers +Glm4Config = None + +logger = logging.getLogger(__name__) + + +class Glm4MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward( + self, + x, + forward_batch=None, + use_reduce_scatter: bool = False, + ): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj( + x, + skip_all_reduce=use_reduce_scatter, + ) + return x + class Glm4Attention(nn.Module): def __init__( self, - config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, layer_id: int = 0, + rope_theta: float = 1000000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 131072, quant_config: Optional[QuantizationConfig] = None, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + partial_rotary_factor: float = 0.5, prefix: str = "", - ): + ) -> None: super().__init__() - self.hidden_size = config.hidden_size + self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads + self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads + self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. @@ -63,27 +136,30 @@ def __init__( # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 - partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = config.hidden_size // self.total_num_heads + if head_dim is not None: + self.head_dim = head_dim + else: + self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.rope_theta = getattr(config, "rope_theta", 1000000) - self.rope_scaling = getattr(config, "rope_scaling", None) + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.partial_rotary_factor = partial_rotary_factor self.qkv_proj = QKVParallelLinear( - self.hidden_size, + hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=config.attention_bias, + bias=True, quant_config=quant_config, prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, - self.hidden_size, + hidden_size, bias=False, quant_config=quant_config, prefix=add_prefix("o_proj", prefix), @@ -92,9 +168,10 @@ def __init__( self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position=config.max_position_embeddings, - base=self.rope_theta, - rope_scaling=self.rope_scaling, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, partial_rotary_factor=partial_rotary_factor, is_neox_style=False, ) @@ -117,14 +194,9 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - context_layer = self.attn( - q, - k, - v, - forward_batch, - ) - attn_output, _ = self.o_proj(context_layer) - return attn_output + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output class Glm4DecoderLayer(nn.Module): @@ -136,15 +208,35 @@ class 
Glm4DecoderLayer(nn.Module):
     def __init__(
         self,
-        config,
-        layer_id: int,
+        config: Glm4Config,
+        layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
         super().__init__()
-        # Self attention.
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Glm4Attention(
-            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+            partial_rotary_factor=partial_rotary_factor,
+            prefix=add_prefix("self_attn", prefix),
         )
 
         # MLP
@@ -199,54 +291,125 @@ def __init__(
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                enable_tp=not is_dp_attention_enabled(),
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to Glm4DecoderLayer
+        decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
-            lambda idx, prefix: Glm4DecoderLayer(
-                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
             ),
-            prefix="model.layers",
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
         )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # For EAGLE3 support
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embed_tokens
 
-    def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
-
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: 
Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None else: - hidden_states = input_embeds - residual = None - for layer in self.layers: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + aux_hidden_states = [] + for i in range(self.start_layer, self.end_layer): + if i in self.layers_to_capture: + aux_hidden_states.append( + hidden_states + residual if residual is not None else hidden_states + ) + layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual, ) - hidden_states, _ = self.norm(hidden_states, residual) + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) == 0: + return hidden_states - return hidden_states + return hidden_states, aux_hidden_states + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling factor attribute!" 
+                    )
 
 
 class Glm4ForCausalLM(nn.Module):
@@ -255,21 +418,54 @@ def __init__(
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> None:
         super().__init__()
-        self.config: Glm4Config = config
+        self.pp_group = get_pp_group()
+        self.config = config
         self.quant_config = quant_config
-        self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        self.model = Glm4Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix="lm_head",
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
 
     @torch.no_grad()
     def forward(
@@ -277,34 +473,138 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                
forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ - # (param_name, weight_name, shard_id) + # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if self.config.tie_word_embeddings and "lm_head.weight" in name: + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + + if "rotary_emb.inv_freq" in name or "projector" in name: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + if self.pp_group.world_size > 1 and self.pp_group.is_last_rank: + # Handle pp weight tying here + # find the embed_tokens.weight in the weights + embed_token_weights = next( + filter(lambda x: x[0] == "model.embed_tokens.weight", weights) + )[1] + loaded_weight = embed_token_weights + else: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name in params_dict.keys(): param = params_dict[name] weight_loader = getattr( @@ -312,7 +612,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) else: - raise KeyError(f"Parameter '{name}' not found in model.") + logger.warning(f"Parameter {name} not found in params_dict") + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) EntryClass = [Glm4ForCausalLM] diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py index d7b0e9618e88..515018c45d8f 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -1,15 +1,35 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Modeling from: +# ./llama.py and +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py +"""Inference-only GLM-4.1V model compatible with HuggingFace weights.""" + import logging -from functools import lru_cache, partial +from functools import lru_cache from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from einops import rearrange from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.attention import vision_utils -from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -20,13 +40,14 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.schedule_batch import MultimodalDataItem +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.glm4 import Glm4Model -from sglang.srt.models.qwen2_5_vl import ( - Qwen2_5_VisionBlock, - Qwen2_5_VLForConditionalGeneration, -) from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -56,7 +77,7 @@ def __init__( super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, - output_sizes=[hidden_features] * 2, + output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, prefix=add_prefix("gate_up_proj", prefix), @@ -77,34 +98,95 @@ def forward(self, x: torch.Tensor): return x -class Glm4vVisionBlock(Qwen2_5_VisionBlock): +class Glm4vVisionBlock(nn.Module): def __init__( self, - config: Glm4vVisionConfig, - norm_layer: Optional[nn.Module] = None, + dim: int, + intermediate_dim: int, + num_heads: int, + attn_implementation: Optional[str] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + num_dummy_heads: int = 0, + rms_norm_eps: float = 1e-5, ) -> None: - super().__init__( - dim=config.hidden_size, - intermediate_dim=config.out_hidden_size, - num_heads=config.num_heads, - hidden_act=config.hidden_act, - norm_layer=norm_layer, + super().__init__() + self.norm1 = RMSNorm(dim, eps=rms_norm_eps) + self.norm2 = RMSNorm(dim, eps=rms_norm_eps) + + if attn_implementation is None: + softmax_in_single_precision = False + qkv_backend = None + flatten_batch = True + elif attn_implementation == "sdpa": + softmax_in_single_precision = False + 
qkv_backend = "sdpa" + flatten_batch = True + elif attn_implementation == "flash_attention_2": + softmax_in_single_precision = False + qkv_backend = "triton_attn" + flatten_batch = True + elif attn_implementation == "eager": + softmax_in_single_precision = True + qkv_backend = "sdpa" + flatten_batch = True + elif attn_implementation == "flash_attention_3": + softmax_in_single_precision = False + qkv_backend = "fa3" + flatten_batch = True + + self.attn = VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + use_qkv_parallel=True, + rotary_embed="normal", + proj_bias=True, + qkv_backend=qkv_backend, + softmax_in_single_precision=softmax_in_single_precision, + flatten_batch=flatten_batch, quant_config=quant_config, - prefix=prefix, - num_dummy_heads=config.num_dummy_heads, - rms_norm_eps=config.rms_norm_eps, + prefix=add_prefix("attn", prefix), + num_dummy_heads=num_dummy_heads, ) - self.mlp = Glm4vVisionMLP( - config.hidden_size, - config.out_hidden_size, - bias=False, + dim, + intermediate_dim, quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> torch.Tensor: + S, B, H = x.shape + # norm1: flatten to 2D -> [S*B, H], then reshape back + x2d = x.reshape(-1, H) + hidden_states = self.norm1(x2d).reshape(S, B, H) + + # Attention expects [B, S, H] + hidden_states = rearrange(hidden_states, "s b h -> b s h") + attn = self.attn( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + attn = rearrange(attn, "b s h -> s b h") + + # norm2 with fused residual-add: also 2D + attn2d = attn.reshape(-1, H) + x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d) + x_norm = x_norm_2d.reshape(S, B, H) + x_after_add = x_after_add_2d.reshape(S, B, H) + + # MLP and final residual + mlp_out = self.mlp(x_norm) + x = x_after_add + mlp_out + return x + class Glm4vVisionPatchEmbed(nn.Module): def __init__( @@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module): def __init__( self, vision_config: Glm4vVisionConfig, - norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -344,17 +425,18 @@ def __init__( hidden_size=self.hidden_size, ) - norm_layer = partial(Glm4vRMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ Glm4vVisionBlock( - config=vision_config, - norm_layer=norm_layer, + dim=self.hidden_size, + intermediate_dim=self.out_hidden_size, + num_heads=self.num_heads, quant_config=quant_config, prefix=add_prefix(f"blocks.{layer_idx}", prefix), + rms_norm_eps=vision_config.rms_norm_eps, ) for layer_idx in range(depth) ] @@ -461,29 +543,30 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: return x -class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): +class Glm4vForConditionalGeneration(nn.Module): def __init__( self, config: Glm4vConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - nn.Module.__init__(self) + super().__init__() self.config = config - vision_utils.update_vit_attn_dummy_heads_config(self.config) - self.model = Glm4Model( - config, - quant_config, - prefix=add_prefix("model", prefix), - ) self.visual = Glm4vVisionModel( config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, prefix=add_prefix("visual", prefix), ) + 
vision_utils.update_vit_attn_dummy_heads_config(self.config) + + self.model = Glm4Model( + config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -494,13 +577,18 @@ def __init__( prefix=add_prefix("lm_head", prefix), ) + self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling + self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling # For EAGLE3 support self.capture_aux_hidden_states = False + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pixel_values = torch.cat( [item.feature.squeeze(0) for item in items], dim=0 @@ -542,20 +630,60 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: video_embeds = torch.split(video_embeds, split_sizes) return torch.cat(video_embeds) - def _update_hf_config(self): - """update hf config to ensure vision attention num_attention_heads is divisible by tp_size""" - tp_size = get_attention_tp_size() - num_heads = self.config.vision_config.num_heads - head_dim = self.config.vision_config.hidden_size // num_heads - num_dummy_heads = 0 + def get_input_embeddings(self): + return self.model.embed_tokens - if num_heads % tp_size != 0: - num_dummy_heads = ( - (num_heads + tp_size - 1) // tp_size - ) * tp_size - num_heads + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ): + """Run forward pass for GLM-4.1V. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for GLM-4.1V + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). 
+ (Use input_metadata.mrope_positions to replace it) + """ + if self.is_mrope_enabled: + positions = forward_batch.mrope_positions + + if not ( + forward_batch.forward_mode.is_decode() + or not forward_batch.contains_image_inputs() + ): + if self.is_mrope_enabled: + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}" + ) + + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + multimodal_model=self, + positions=positions, + ) - setattr(self.config.vision_config, "head_dim", head_dim) - setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states + ) + else: + return self.pooler(hidden_states, forward_batch) def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): """pad attn qkv weights for dummy heads""" @@ -598,13 +726,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if "language_model." in name: - name = name.replace("language_model.", "") - if "model.visual." in name: - name = name.replace("model.visual.", "visual.") - if "rotary_emb.inv_freq" in name: continue + if "language_model" in name: + name = name.replace(r"model.language_model.", r"model.") + if "model.visual." in name: + name = name.replace("model.visual.", "visual.") for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -639,5 +766,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + self.model.embed_tokens.weight = embed + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + del self.lm_head.weight + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + EntryClass = [Glm4vForConditionalGeneration] diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 2e07823bf718..8ec27dc0e153 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -53,7 +53,6 @@ def __init__( ) self.visual = Glm4vVisionModel( config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, prefix=add_prefix("visual", prefix), )
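
For context on the mrope positions referenced in the GLM-4.1V forward pass above: when `mrope_section` is present in `rope_scaling`, `forward_batch.mrope_positions` carries three position streams (temporal, height, width) of shape `(3, seq_len)`, and the half-size rotary table is partitioned across those streams, mirroring what `_triton_mrope_forward` does per token with its `mrope_section_t/h/w` arguments. A minimal PyTorch sketch of that selection (illustrative only; the function name and the `(3, seq_len, rotary_dim // 2)` layout are assumptions, not the kernel's actual memory layout):

import torch

def select_mrope_cos_sin(
    cos: torch.Tensor,         # (3, seq_len, rotary_dim // 2): one table per t/h/w stream
    mrope_section: list[int],  # e.g. [16, 24, 24]; must sum to rotary_dim // 2
) -> torch.Tensor:
    # Split the half-size table along the channel dim and take each chunk
    # from its corresponding position stream: t for the first section,
    # h for the second, w for the third.
    chunks = cos.split(mrope_section, dim=-1)
    return torch.cat(
        [chunk[i] for i, chunk in enumerate(chunks)], dim=-1
    )  # (seq_len, rotary_dim // 2)

The resulting per-token rows correspond to what the kernel consumes as cos_row/sin_row, after which either rotary style (NeoX half-split or GLM4's interleaved, non-NeoX layout) applies unchanged.
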