From 24025bc5cf7f9c42ad5395d221c4d1c59dfd6c78 Mon Sep 17 00:00:00 2001 From: pengmeng Date: Tue, 17 Jun 2025 14:22:37 +0800 Subject: [PATCH 1/5] add hunyuan support --- python/sglang/srt/models/hunyuan.py | 816 ++++++++++++++++++++++++++++ 1 file changed, 816 insertions(+) create mode 100644 python/sglang/srt/models/hunyuan.py diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py new file mode 100644 index 00000000000..48f3ef03576 --- /dev/null +++ b/python/sglang/srt/models/hunyuan.py @@ -0,0 +1,816 @@ +# coding=utf-8 +# Copyright 2024 The HunYuan team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only HunYuan model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import re +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +def _is_moe(config: PretrainedConfig) -> bool: + if getattr(config, "num_experts", None) and \ + ((isinstance(config.num_experts, int) and config.num_experts > 1) or \ + (isinstance(config.num_experts, list) and max(config.num_experts) > 1)): + return True + else: + return False + +def _get_cla_factor(config: PretrainedConfig) -> int: + if not getattr(config, "use_cla", False): + return 1 + return getattr(config, "cla_share_factor", 1) + + +class HunYuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + 
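# The __init__ below builds a gated SiLU MLP: gate_proj and up_proj are merged
# into a single column-parallel gate_up_proj, SiluAndMul splits the fused output
# in half and computes silu(gate) * up, and down_proj maps back to hidden_size.
# reduce_results=False lets a caller (the shared expert in the MoE block) defer
# the tensor-parallel all-reduce and perform it once after combining outputs.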
super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class HunYuanSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + layer_id: int = -1, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Get layer_id topk if config.moe_topk is a list + if isinstance(config.moe_topk, list): + assert layer_id >= 0 + assert len(config.moe_topk) > layer_id + top_k = config.moe_topk[layer_id] + else: + top_k = config.moe_topk + + # If it is moe, moe_intermediate_size is preferred + intermediate_size = config.intermediate_size + if config.moe_intermediate_size is not None: + intermediate_size = config.moe_intermediate_size if isinstance(config.moe_intermediate_size, int) else config.moe_intermediate_size[layer_id] + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=True if top_k > 1 else False, + quant_config=quant_config) + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None) + if config.use_mixed_mlp_moe > 0: + # Get layer_id num_shared_expert if config.num_shared_expert is a list + if isinstance(config.num_shared_expert, list): + assert layer_id >= 0 + assert len(config.num_shared_expert) > layer_id + num_shared_expert = config.num_shared_expert[layer_id] + else: + num_shared_expert = config.num_shared_expert + + self.shared_mlp = HunYuanMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size * num_shared_expert, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + else: + self.shared_mlp = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
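# Routing, in reference form: tokens are flattened to (num_tokens, hidden_size),
# self.gate produces router logits over config.num_experts, and FusedMoE selects
# the top-k experts per token (renormalizing their softmax weights when top_k > 1)
# and returns the weighted sum of the selected experts' outputs. The optional
# shared expert runs on every token and is added before a single all-reduce.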
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_mlp is not None: + shared_output = self.shared_mlp(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class HunYuanAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attention_type: str = "self", + layer_id: int = -1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.attention_type = attention_type + self.layer_id = layer_id + + if attention_type == "self": + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + elif attention_type == "cross": + self.q_proj = ColumnParallelLinear( + hidden_size, + hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + else: + raise RuntimeError("Not support attnention type") + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + 
quant_config=quant_config) + + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_states: Optional[Tuple[torch.Tensor]] = None, + ) -> torch.Tensor: + if self.attention_type == "self": + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + ori_k = k + if self.use_qk_norm: + q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) + k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + elif self.attention_type == "cross": + assert kv_states is not None + ori_k, v = kv_states # use last layer kv, + k = ori_k + q, _ = self.q_proj(hidden_states) + k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding + q, _ = self.rotary_emb(positions, q, k_tmp) + if self.use_qk_norm: + q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) + k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + else: + raise RuntimeError("Not support attnention type") + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output, (ori_k, v) + + +class HunYuanDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_id: int = -1, + ) -> None: + super().__init__() + assert layer_id >= 0 + self.layer_id = layer_id + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size \ + if isinstance(config.intermediate_size, int) else config.intermediate_size[layer_id] + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + cla_factor = _get_cla_factor(config) + attention_type = "cross" \ + if layer_id >= 0 and layer_id % cla_factor != 0 else "self" + self.self_attn = HunYuanAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attention_type=attention_type, + layer_id=layer_id, + ) + if _is_moe(config): + self.mlp = HunYuanSparseMoeBlock(config=config, + quant_config=quant_config, + layer_id=layer_id,) + else: + self.mlp = HunYuanMLP( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + kv_states: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states, ori_kv_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + kv_states=kv_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual, ori_kv_states + + +class HunYuanModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: HunYuanDecoderLayer(config=config, + layer_id=int( + prefix.split(".")[-1]), + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + cla_factor = _get_cla_factor(self.config) + prev_kv_states = None + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual, kv_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + # kv_caches[(i - self.start_layer) // cla_factor], + attn_metadata, + residual, + prev_kv_states, + ) + + if (i - self.start_layer) % cla_factor == 0: + prev_kv_states = kv_states + else: + prev_kv_states = None + + if not 
get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class HunYuanForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = HunYuanModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def _split_qkv_weight(self, qkv: torch.Tensor): + num_attention_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + num_key_value_groups = 
num_attention_heads // num_kv_heads + hidden_size = self.config.hidden_size + attention_head_dim = self.config.hidden_size // num_attention_heads + + qkv = qkv.reshape(num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size) + q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) + q = q.reshape(-1, hidden_size) + k = k.reshape(-1, hidden_size) + v = v.reshape(-1, hidden_size) + return torch.concat((q, k, v)) + # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)), + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + cla_factor = _get_cla_factor(self.config) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + num_attention_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + split_params_mapping = [ + (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), + (".qkv_proj", ".qkv_proj", + num_attention_heads + num_kv_heads * 2, + [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], + self._split_qkv_weight, + ), + ] + + if _is_moe(self.config): + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + else: + expert_params_mapping = {} + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "gate_proj_bias" in name: + name = name.replace("gate_proj_bias", "gate_proj.bias") + if "up_proj_bias" in name: + name = name.replace("up_proj_bias", "up_proj.bias") + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + + is_found = False + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + # cross layer only have q_proj, skip qkv pack + if weight_name == ".q_proj": + match = re.search(r'layers\.\d+', name) + if match: + layer_id = int(match.group(0).split('.')[-1]) + if cla_factor > 1 and layer_id % cla_factor != 0: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + is_found = True + break + if is_found: + continue + + for (param_name, weight_name, den, split_param, func) in split_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + assert loaded_weight.shape[0] % den == 0 + units = loaded_weight.shape[0] // den + + param = params_dict[name] + weight_loader = param.weight_loader + offset = 0 + for shard_id, num in split_param: + new_offset = offset + num * units + if func: + weight_loader(param, func(loaded_weight)[offset:new_offset], shard_id) + else: + weight_loader(param, loaded_weight[offset:new_offset], shard_id) + offset = new_offset + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + if "mlp.gate.wg." in name: + name = name.replace("wg.", "") + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") From 6a9313b0fa6fdd8ecfbaff8447a6ccf0b3e8307c Mon Sep 17 00:00:00 2001 From: Peng Meng Date: Tue, 17 Jun 2025 12:24:48 +0000 Subject: [PATCH 2/5] hunyuan 1 bs test works --- python/sglang/srt/models/hunyuan.py | 256 +++++++++------------------- 1 file changed, 82 insertions(+), 174 deletions(-) diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index 48f3ef03576..feed66cc689 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only HunYuan model compatible with HuggingFace weights.""" +import logging +from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import re @@ -22,31 +25,31 @@ from sglang.srt.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ( +from sglang.srt.layers.radix_attention import RadixAttention + +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( 
- default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip +from sglang.srt.model_loader.weight_utils import ( + kv_cache_scales_loader, maybe_remap_kv_scale_name) +from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, is_hip -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +expert_distribution_recorder = ExpertDistributionRecorder() def _is_moe(config: PretrainedConfig) -> bool: @@ -195,7 +198,6 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, prefix: str = "", attention_type: str = "self", layer_id: int = -1, @@ -269,12 +271,12 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) - self.attn = Attention(self.num_heads, + self.attn = RadixAttention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + layer_id=layer_id, + prefix=f"{prefix}.attn",) if self.use_qk_norm: self.query_layernorm = RMSNorm(self.head_dim, @@ -286,8 +288,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + forward_batch: ForwardBatch, kv_states: Optional[Tuple[torch.Tensor]] = None, ) -> torch.Tensor: if self.attention_type == "self": @@ -296,8 +297,10 @@ def forward( q, k = self.rotary_emb(positions, q, k) ori_k = k if self.use_qk_norm: - q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) - k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + #q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) + #k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous()) + k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous()) elif self.attention_type == "cross": assert kv_states is not None ori_k, v = kv_states # use last layer kv, @@ -311,7 +314,7 @@ def forward( else: raise RuntimeError("Not support attnention type") - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output, (ori_k, v) @@ -321,7 +324,6 @@ class HunYuanDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", layer_id: int = -1, @@ -358,7 +360,6 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, - cache_config=cache_config, prefix=f"{prefix}.self_attn", attention_type=attention_type, layer_id=layer_id, @@ -385,8 +386,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], kv_states: Optional[Tuple[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -400,8 +400,7 @@ def forward( 
hidden_states, ori_kv_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, + forward_batch=forward_batch, kv_states=kv_states, ) @@ -417,41 +416,31 @@ class HunYuanModel(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( + + self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - quant_config=quant_config, ) - else: - self.embed_tokens = PPMissingLayer() - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: HunYuanDecoderLayer(config=config, - layer_id=int( - prefix.split(".")[-1]), - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() + + self.layers = nn.ModuleList( + [HunYuanDecoderLayer(config=config, + layer_id=layer_id, + quant_config=quant_config, + #prefix=prefix + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -460,52 +449,38 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is not None: + hidden_states = input_embeds else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] + hidden_states = self.get_input_embeddings(input_ids) + residual = None + cla_factor = _get_cla_factor(self.config) prev_kv_states = None - for i in range(self.start_layer, self.end_layer): + for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual, kv_states = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - # kv_caches[(i - self.start_layer) // cla_factor], - attn_metadata, + forward_batch, residual, prev_kv_states, ) - if (i - self.start_layer) % cla_factor == 0: + if False : #(i - self.start_layer) % cla_factor == 0: prev_kv_states = kv_states else: prev_kv_states = None - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = 
self.norm(hidden_states, residual) return hidden_states -class HunYuanForCausalLM(nn.Module, SupportsLoRA): +class HunYuanForCausalLM(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -518,11 +493,6 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -540,87 +510,41 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.lora_config = lora_config self.model = HunYuanModel(config, - cache_config, quant_config, - lora_config=lora_config, prefix="model") - if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) - self.sampler = Sampler() - else: - self.lm_head = PPMissingLayer() + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(config, + logit_scale=logit_scale) + self.sampler = Sampler() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) - return model_output - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + forward_batch: ForwardBatch, + input_embeds: torch.Tensor=None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, 
self.config.hidden_size), - dtype=dtype, - device=device), - }) def _split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads @@ -688,14 +612,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # processed with quantization, LoRA, fine-tuning, etc. if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - if scale_name := get_compressed_tensors_cache_scale(name): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - continue is_found = False for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -715,9 +631,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -735,9 +648,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if is_pp_missing_parameter(name, self): - continue - assert loaded_weight.shape[0] % den == 0 units = loaded_weight.shape[0] // den @@ -763,8 +673,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -779,9 +687,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name is None: continue - if is_pp_missing_parameter(name, self): - continue - if "mlp.gate.wg." in name: name = name.replace("wg.", "") @@ -814,3 +719,6 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") + + +EntryClass=HunYuanForCausalLM From 997dc70025780afcbcf0546514ebda6fe68ed59d Mon Sep 17 00:00:00 2001 From: Peng Meng Date: Wed, 18 Jun 2025 08:42:46 +0000 Subject: [PATCH 3/5] add hunyuan custom ROPE --- python/sglang/srt/layers/rotary_embedding.py | 61 +++++++++++++++++--- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 8ae191b519b..4686b761b43 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -806,6 +806,38 @@ def forward( key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) +class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. 
+ + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_alpha: float, + dtype: torch.dtype, + ) -> None: + self.scaling_alpha = scaling_alpha + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + max_len = self.max_position_embeddings + base = self.base * self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim - 2)) + + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1151,15 +1183,26 @@ def get_rope( ) elif scaling_type == "dynamic": scaling_factor = rope_scaling["factor"] - rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_factor, - dtype, - ) + if "alpha" in rope_scaling: + rotary_emb = DynamicNTKAlphaRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling["alpha"], + dtype + ) + else: + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) elif scaling_type == "yarn": scaling_factor = rope_scaling["factor"] original_max_position = rope_scaling["original_max_position_embeddings"] From c8d6811673dc9af87fbba0794c563a466f38c11d Mon Sep 17 00:00:00 2001 From: Peng Meng Date: Wed, 18 Jun 2025 12:21:20 +0000 Subject: [PATCH 4/5] fix code format --- python/sglang/srt/layers/rotary_embedding.py | 13 +- python/sglang/srt/models/hunyuan.py | 313 +++++++++++-------- 2 files changed, 189 insertions(+), 137 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 4686b761b43..780ff93352d 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -806,6 +806,7 @@ def forward( key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) + class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. 
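For reference, the alpha scaling added here reduces to one change to the rotary base before the inverse frequencies are computed. A minimal standalone sketch of that computation, assuming the standard RoPE inverse-frequency formula used by the base RotaryEmbedding class (the helper name and sizes below are illustrative):

import torch

def ntk_alpha_inv_freq(rotary_dim: int, base: float, scaling_alpha: float) -> torch.Tensor:
    # Effective base, matching DynamicNTKAlphaRotaryEmbedding._compute_cos_sin_cache:
    # base' = base * alpha ** (rotary_dim / (rotary_dim - 2))
    scaled_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
    return 1.0 / (scaled_base ** exponents)

# Larger alpha lowers every frequency, stretching the rotary period so longer
# positions stay in-distribution without retraining.
inv_freq = ntk_alpha_inv_freq(rotary_dim=128, base=10000.0, scaling_alpha=1000.0)
freqs = torch.einsum("i,j->ij", torch.arange(4096, dtype=torch.float), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)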
@@ -823,12 +824,15 @@ def __init__( dtype: torch.dtype, ) -> None: self.scaling_alpha = scaling_alpha - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_cos_sin_cache(self) -> torch.Tensor: max_len = self.max_position_embeddings - base = self.base * self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim - 2)) + base = self.base * self.scaling_alpha ** ( + self.rotary_dim / (self.rotary_dim - 2) + ) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float) @@ -839,6 +843,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: cache = torch.cat((cos, sin), dim=-1) return cache + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1191,7 +1196,7 @@ def get_rope( base, is_neox_style, rope_scaling["alpha"], - dtype + dtype, ) else: rotary_emb = DynamicNTKScalingRotaryEmbedding( diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index feed66cc689..c33838fd147 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -13,57 +13,67 @@ # limitations under the License. """Inference-only HunYuan model compatible with HuggingFace weights.""" import logging +import re from dataclasses import dataclass from enum import Enum, auto from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -import re import torch from torch import nn from transformers import PretrainedConfig -from sglang.srt.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - ColumnParallelLinear, - RowParallelLinear) +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE -from sglang.srt.layers.quantization.base_config import ( - QuantizationConfig) +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention - from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from sglang.srt.model_loader.weight_utils import ( - kv_cache_scales_loader, maybe_remap_kv_scale_name) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import add_prefix, is_hip expert_distribution_recorder = ExpertDistributionRecorder() def 
_is_moe(config: PretrainedConfig) -> bool: - if getattr(config, "num_experts", None) and \ - ((isinstance(config.num_experts, int) and config.num_experts > 1) or \ - (isinstance(config.num_experts, list) and max(config.num_experts) > 1)): + if getattr(config, "num_experts", None) and ( + (isinstance(config.num_experts, int) and config.num_experts > 1) + or (isinstance(config.num_experts, list) and max(config.num_experts) > 1) + ): return True else: return False + def _get_cla_factor(config: PretrainedConfig) -> int: if not getattr(config, "use_cla", False): return 1 - return getattr(config, "cla_share_factor", 1) + return getattr(config, "cla_share_factor", 1) class HunYuanMLP(nn.Module): @@ -84,16 +94,21 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - reduce_results=reduce_results) + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + reduce_results=reduce_results, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -117,8 +132,9 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - + f"the number of experts {config.num_experts}." 
+ ) + # Get layer_id topk if config.moe_topk is a list if isinstance(config.moe_topk, list): assert layer_id >= 0 @@ -130,20 +146,25 @@ def __init__( # If it is moe, moe_intermediate_size is preferred intermediate_size = config.intermediate_size if config.moe_intermediate_size is not None: - intermediate_size = config.moe_intermediate_size if isinstance(config.moe_intermediate_size, int) else config.moe_intermediate_size[layer_id] - - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=top_k, - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - reduce_results=False, - renormalize=True if top_k > 1 else False, - quant_config=quant_config) - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None) + intermediate_size = ( + config.moe_intermediate_size + if isinstance(config.moe_intermediate_size, int) + else config.moe_intermediate_size[layer_id] + ) + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=True if top_k > 1 else False, + quant_config=quant_config, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, config.num_experts, bias=False, quant_config=None + ) if config.use_mixed_mlp_moe > 0: # Get layer_id num_shared_expert if config.num_shared_expert is a list if isinstance(config.num_shared_expert, list): @@ -174,13 +195,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) @@ -219,8 +240,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -271,18 +293,18 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) - self.attn = RadixAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id, - prefix=f"{prefix}.attn",) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=f"{prefix}.attn", + ) if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -297,20 +319,24 @@ def forward( q, k = self.rotary_emb(positions, q, k) ori_k = k if self.use_qk_norm: - #q = 
self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) - #k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + # q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) + # k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous()) k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous()) elif self.attention_type == "cross": assert kv_states is not None - ori_k, v = kv_states # use last layer kv, + ori_k, v = kv_states # use last layer kv, k = ori_k q, _ = self.q_proj(hidden_states) - k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding + k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding q, _ = self.rotary_emb(positions, q, k_tmp) if self.use_qk_norm: - q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous()) - k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + q = self.query_layernorm( + q.view(-1, self.num_heads, self.head_dim).contiguous() + ) + k = self.key_layernorm( + k.view(-1, self.num_kv_heads, self.head_dim).contiguous() + ) else: raise RuntimeError("Not support attnention type") @@ -332,29 +358,36 @@ def __init__( assert layer_id >= 0 self.layer_id = layer_id self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size \ - if isinstance(config.intermediate_size, int) else config.intermediate_size[layer_id] + self.intermediate_size = ( + config.intermediate_size + if isinstance(config.intermediate_size, int) + else config.intermediate_size[layer_id] + ) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) cla_factor = _get_cla_factor(config) - attention_type = "cross" \ - if layer_id >= 0 and layer_id % cla_factor != 0 else "self" + attention_type = ( + "cross" if layer_id >= 0 and layer_id % cla_factor != 0 else "self" + ) self.self_attn = HunYuanAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -365,9 +398,11 @@ def __init__( layer_id=layer_id, ) if _is_moe(config): - self.mlp = HunYuanSparseMoeBlock(config=config, - quant_config=quant_config, - layer_id=layer_id,) + self.mlp = HunYuanSparseMoeBlock( + config=config, + quant_config=quant_config, + layer_id=layer_id, + ) else: self.mlp = HunYuanMLP( hidden_size=self.hidden_size, @@ -377,10 +412,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - 
eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -395,8 +430,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, ori_kv_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -405,8 +439,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual, ori_kv_states @@ -426,17 +459,19 @@ def __init__( self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - ) + self.vocab_size, + config.hidden_size, + ) self.layers = nn.ModuleList( - [HunYuanDecoderLayer(config=config, - layer_id=layer_id, - quant_config=quant_config, - #prefix=prefix - ) - for layer_id in range(config.num_hidden_layers) + [ + HunYuanDecoderLayer( + config=config, + layer_id=layer_id, + quant_config=quant_config, + # prefix=prefix + ) + for layer_id in range(config.num_hidden_layers) ] ) @@ -458,7 +493,6 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) residual = None - cla_factor = _get_cla_factor(self.config) prev_kv_states = None for i in range(len(self.layers)): @@ -471,7 +505,7 @@ def forward( prev_kv_states, ) - if False : #(i - self.start_layer) % cla_factor == 0: + if False: # (i - self.start_layer) % cla_factor == 0: prev_kv_states = kv_states else: prev_kv_states = None @@ -516,9 +550,7 @@ def __init__( self.config = config - self.model = HunYuanModel(config, - quant_config, - prefix="model") + self.model = HunYuanModel(config, quant_config, prefix="model") self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( config.vocab_size, @@ -529,8 +561,7 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(config, - logit_scale=logit_scale) + self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) self.sampler = Sampler() def forward( @@ -538,22 +569,25 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, - input_embeds: torch.Tensor=None, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) - def _split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) num_key_value_groups = num_attention_heads // num_kv_heads hidden_size = self.config.hidden_size attention_head_dim = self.config.hidden_size // num_attention_heads - qkv = qkv.reshape(num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size) + qkv = qkv.reshape( + num_kv_heads, 
num_key_value_groups + 2, attention_head_dim, hidden_size + ) q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) q = q.reshape(-1, hidden_size) k = k.reshape(-1, hidden_size) @@ -573,10 +607,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) split_params_mapping = [ (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), - (".qkv_proj", ".qkv_proj", + ( + ".qkv_proj", + ".qkv_proj", num_attention_heads + num_kv_heads * 2, [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], self._split_qkv_weight, @@ -590,7 +628,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) else: expert_params_mapping = {} @@ -602,8 +641,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace("gate_proj_bias", "gate_proj.bias") if "up_proj_bias" in name: name = name.replace("up_proj_bias", "up_proj.bias") - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -614,16 +652,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue is_found = False - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: continue # cross layer only have q_proj, skip qkv pack if weight_name == ".q_proj": - match = re.search(r'layers\.\d+', name) + match = re.search(r"layers\.\d+", name) if match: - layer_id = int(match.group(0).split('.')[-1]) + layer_id = int(match.group(0).split(".")[-1]) if cla_factor > 1 and layer_id % cla_factor != 0: continue name = name.replace(weight_name, param_name) @@ -640,7 +678,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if is_found: continue - for (param_name, weight_name, den, split_param, func) in split_params_mapping: + for param_name, weight_name, den, split_param, func in split_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -657,7 +695,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for shard_id, num in split_param: new_offset = offset + num * units if func: - weight_loader(param, func(loaded_weight)[offset:new_offset], shard_id) + weight_loader( + param, func(loaded_weight)[offset:new_offset], shard_id + ) else: weight_loader(param, loaded_weight[offset:new_offset], shard_id) offset = new_offset @@ -675,11 +715,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Remapping the name of FP8 kv-scale. 
@@ -691,8 +733,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace("wg.", "") param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale @@ -702,9 +745,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn @@ -717,8 +763,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if hasattr(layer_self_attn, "kv_scale"): layer_self_attn.attn._kv_scale = scaling_factor else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") - + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" + ) + -EntryClass=HunYuanForCausalLM +EntryClass = HunYuanForCausalLM From 7682241cb40bb40ffb356e8e27d04c1e622e2019 Mon Sep 17 00:00:00 2001 From: Peng Meng Date: Thu, 26 Jun 2025 11:39:43 +0800 Subject: [PATCH 5/5] update HunYuan to HunYuanMoEV1 --- python/sglang/srt/models/hunyuan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index c33838fd147..00300bed54a 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -514,7 +514,7 @@ def forward( return hidden_states -class HunYuanForCausalLM(nn.Module): +class HunYuanMoEV1ForCausalLM(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -768,4 +768,4 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: ) -EntryClass = HunYuanForCausalLM +EntryClass = HunYuanMoEV1ForCausalLM
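
Note on the packed-QKV layout handled by `_split_qkv_weight` above: the helper assumes the HunYuan checkpoint stores the fused QKV projection grouped per KV head as [num_key_value_groups query heads, 1 key head, 1 value head], and it reshapes/splits the weight before the slices are passed on to the stacked q/k/v weight loaders. A minimal standalone sketch of that reshape/split, using hypothetical sizes rather than the real HunYuan config values, is:

    import torch

    # Hypothetical sizes for illustration only; the real values come from the
    # model config (num_attention_heads, num_key_value_heads, hidden_size).
    hidden_size = 4096
    num_heads = 32
    num_kv_heads = 8
    head_dim = hidden_size // num_heads       # 128
    groups = num_heads // num_kv_heads        # query heads per KV head: 4

    # Packed fused-QKV weight: rows grouped per KV head as
    # [groups query heads, 1 key head, 1 value head].
    qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden_size)

    packed = qkv.reshape(num_kv_heads, groups + 2, head_dim, hidden_size)
    q, k, v = torch.split(packed, (groups, 1, 1), dim=1)
    q = q.reshape(-1, hidden_size)  # (num_heads * head_dim, hidden_size)
    k = k.reshape(-1, hidden_size)  # (num_kv_heads * head_dim, hidden_size)
    v = v.reshape(-1, hidden_size)

    assert q.shape == (num_heads * head_dim, hidden_size)
    assert k.shape == (num_kv_heads * head_dim, hidden_size)
    assert v.shape == (num_kv_heads * head_dim, hidden_size)

Splitting in this order reproduces the row layout that the per-shard "q"/"k"/"v" loaders in `load_weights` expect; the concrete head counts above are placeholders, not the shipped HunYuan configuration.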