From 56d5cc7d2a34198929331b5968a028557a6c4676 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Fri, 1 Aug 2025 18:38:39 +0530 Subject: [PATCH 1/4] Support for Ling-Lite and Ling-Plus MoE LLM --- python/sglang/srt/models/bailing_moe.py | 448 ++++++++++++++++++++++ test/srt/models/test_generation_models.py | 1 + 2 files changed, 449 insertions(+) create mode 100644 python/sglang/srt/models/bailing_moe.py diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py new file mode 100644 index 000000000000..81ae033d4481 --- /dev/null +++ b/python/sglang/srt/models/bailing_moe.py @@ -0,0 +1,448 @@ +from collections.abc import Iterable +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from sglang.python.sglang.srt.layers.moe.ep_moe import layer +from sglang.python.sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.python.sglang.srt.layers.moe.topk import TopK +from sglang.python.sglang.srt.layers.utils import PPMissingLayer +from sglang.python.sglang.srt.models.transformers import maybe_prefix +from sglang.python.sglang.srt.utils import add_prefix, make_layers +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + parallel_state, + tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + + assert self.total_num_heads % attn_tp_size == 0 + assert self.total_kv_heads % attn_tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // attn_tp_size + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // attn_tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = 
RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=f"{prefix}.dense", + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split( + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1 + ) + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn(q, k, v, forward_batch) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> None: + super().__init__() + self.tp_size = tp_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + tp_rank=tp_rank, + tp_size=tp_size, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + tp_rank=tp_rank, + tp_size=tp_size, + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear( + self.hidden_size, self.num_experts, bias=False, quant_config=None + ) + self.topk = TopK( + top_k=self.top_k, + renormalize=False, + ) + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + layer_id=layer_id, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + if self.num_shared_experts > 0: + intermediate_size = config.moe_intermediate_size * self.num_shared_experts + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention( + config, layer_id, quant_config, prefix=f"{prefix}.attention" + ) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.mlp = BailingMoE( + config=config, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + mlp_output = self.mlp(hidden_states) + hidden_states = residual + mlp_output + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + self.pp_group = parallel_state.get_pp_group() + + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), + ) + + else: + self.embed_tokens = PPMissingLayer() + + 
self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: BailingMoeBlock( + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + if parallel_state.get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor]: + + if input_embeds is None: + + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = input_embeds + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + forward_batch, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BailingMoeForCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.model = BailingMoeModel(config=config, quant_config=quant_config) + self.lm_head = ParallelLMHead( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> LogitsProcessorOutput: + + hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "model.word_embeddings" in name: + name = name.replace("model.word_embeddings", "model.embed_tokens") + + is_loaded = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in name and "expert" not in name: + name = name.replace(weight_name, param_name) + param = params_dict[name] + param.weight_loader(param, loaded_weight, shard_id) + is_loaded = True + break + if is_loaded: + continue + + for p_name, w_name, e_id, s_id in expert_params_mapping: + if w_name in name: + name = name.replace(w_name, p_name) + param = params_dict[name] + param.weight_loader( + param, loaded_weight, name, shard_id=s_id, expert_id=e_id + ) + is_loaded = True + break + if is_loaded: + continue + + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = BailingMoeForCausalLM diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 
f8acf4b189ea..e56e8282c775 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -67,6 +67,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/phi-1_5", trust_remote_code=True), ModelCase("adept/persimmon-8b-chat"), + ModelCase("inclusionAI/Ling-lite"), ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), From 6f0d78124ce5adf31fbfd5a753e4215130af2f3c Mon Sep 17 00:00:00 2001 From: ppraneth Date: Fri, 1 Aug 2025 19:19:56 +0530 Subject: [PATCH 2/4] Support for Ling-Lite and Ling-Plus MoE LLM --- python/sglang/srt/models/bailing_moe.py | 50 ++++++++++--------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index 81ae033d4481..a0e1acd90015 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -6,12 +6,6 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from sglang.python.sglang.srt.layers.moe.ep_moe import layer -from sglang.python.sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.python.sglang.srt.layers.moe.topk import TopK -from sglang.python.sglang.srt.layers.utils import PPMissingLayer -from sglang.python.sglang.srt.models.transformers import maybe_prefix -from sglang.python.sglang.srt.utils import add_prefix, make_layers from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -28,15 +22,21 @@ RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.moe.ep_moe import layer +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.transformers import maybe_prefix +from sglang.srt.utils import add_prefix, make_layers class BailingAttention(nn.Module): @@ -173,34 +173,25 @@ def forward(self, x): class BailingMoE(nn.Module): - def __init__( self, - intermediate_size: int, - layer_id: int, config: PretrainedConfig, + layer_id: int, quant_config: Optional[QuantizationConfig] = None, - reduce_results: Optional[bool] = True, prefix: str = "", ): super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok - self.norm_expert_prob = config.norm_topk_prob self.hidden_size = config.hidden_size - self.quant_config = quant_config self.num_shared_experts = config.num_shared_experts - # Gate always runs at half / full precision for now. 
+ self.norm_expert_prob = config.norm_topk_prob + self.gate = ReplicatedLinear( self.hidden_size, self.num_experts, bias=False, quant_config=None ) - self.topk = TopK( - top_k=self.top_k, - renormalize=False, - ) + self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob) self.experts = FusedMoE( num_experts=self.num_experts, @@ -208,10 +199,8 @@ def __init__( layer_id=layer_id, hidden_size=self.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_expert_prob, quant_config=quant_config, - prefix=f"{prefix}.experts", + prefix=add_prefix("experts", prefix), ) if self.num_shared_experts > 0: @@ -220,28 +209,29 @@ def __init__( intermediate_size=intermediate_size, config=config, quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_experts", + prefix=add_prefix("shared_experts", prefix), ) else: self.shared_experts = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_size) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + + shared_output = None if self.num_shared_experts > 0: shared_output = self.shared_experts(hidden_states) - # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(hidden_states, topk_output) - if self.num_shared_experts > 0: + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states.view(num_tokens, hidden_size) + return final_hidden_states.view(orig_shape) class BailingMoeBlock(nn.Module): @@ -322,7 +312,7 @@ def __init__( self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) - self.start_layer, self.end_layer, self.layers = make_layers( + self.layers = make_layers( config.num_hidden_layers, lambda idx, prefix: BailingMoeBlock( config=config, From 24ad70a60f6b9887c81d3b0b82809c43c78eacea Mon Sep 17 00:00:00 2001 From: ppraneth Date: Fri, 1 Aug 2025 21:03:54 +0530 Subject: [PATCH 3/4] Support for Ling-Lite and Ling-Plus MoE LLM --- docs/supported_models/generative_models.md | 1 + python/sglang/srt/models/bailing_moe.py | 227 ++++++++++----------- test/srt/models/test_generation_models.py | 2 +- 3 files changed, 109 insertions(+), 121 deletions(-) diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 375e24cd453b..62a2ed14d52c 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -47,5 +47,6 @@ in the GitHub search bar. | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | | **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. | | **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. 
| +| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models: Ling-Lite (16.8B total, 2.75B active) and Ling-Plus (290B total, 28.8B active), built for flexibility and high performance across NLP and complex reasoning tasks. Open-source and scalable, fostering rapid innovation and wide adoption. | | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index a0e1acd90015..73e5a9a16366 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -1,5 +1,8 @@ +# Copyright 2023-2024 SGLang Team +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py + from collections.abc import Iterable -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -7,13 +10,10 @@ from transformers.configuration_utils import PretrainedConfig from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -21,21 +21,18 @@ ReplicatedLinear, RowParallelLinear, ) -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.moe.ep_moe import layer -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.transformers import maybe_prefix from sglang.srt.utils import add_prefix, make_layers @@ -50,35 +47,30 @@ def __init__( ): super().__init__() self.hidden_size = config.hidden_size - - attn_tp_rank = get_attention_tp_rank() - attn_tp_size = get_attention_tp_size() + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads - self.total_kv_heads = config.num_key_value_heads + self.total_num_kv_heads = config.num_key_value_heads - assert self.total_num_heads % attn_tp_size == 0 - assert self.total_kv_heads % attn_tp_size == 0 - assert self.total_num_heads >= self.total_kv_heads + assert self.total_num_heads % tp_size == 0 + assert self.total_num_kv_heads % tp_size == 0 - self.num_heads = self.total_num_heads // attn_tp_size + self.num_heads 
= self.total_num_heads // tp_size self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) - self.q_size_per_rank = self.head_dim * self.num_heads + self.q_size = self.num_heads * self.head_dim - self.num_kv_heads = self.total_kv_heads // attn_tp_size - self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.kv_size = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 self.query_key_value = QKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, - self.total_kv_heads, + self.total_num_kv_heads, bias=(config.use_bias or config.use_qkv_bias), quant_config=quant_config, - tp_rank=attn_tp_rank, - tp_size=attn_tp_size, - prefix=f"{prefix}.query_key_value", + prefix=add_prefix("query_key_value", prefix), ) self.dense = RowParallelLinear( @@ -86,9 +78,7 @@ def __init__( self.hidden_size, bias=config.use_bias, quant_config=quant_config, - tp_rank=attn_tp_rank, - tp_size=attn_tp_size, - prefix=f"{prefix}.dense", + prefix=add_prefix("dense", prefix), ) self.attn = RadixAttention( @@ -113,25 +103,19 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - forward_batch: ForwardBatch, position_ids: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: - qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split( - [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1 - ) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - context_layer = self.attn(q, k, v, forward_batch) - attn_output, _ = self.dense(context_layer) return attn_output class BailingMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -139,19 +123,14 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, reduce_results: Optional[bool] = True, prefix: str = "", - tp_rank: Optional[int] = None, - tp_size: Optional[int] = None, ) -> None: super().__init__() - self.tp_size = tp_size self.gate_up_proj = MergedColumnParallelLinear( config.hidden_size, [intermediate_size] * 2, bias=config.use_bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - tp_rank=tp_rank, - tp_size=tp_size, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, @@ -159,9 +138,7 @@ def __init__( bias=config.use_bias, quant_config=quant_config, reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - tp_rank=tp_rank, - tp_size=tp_size, + prefix=add_prefix("down_proj", prefix), ) self.act_fn = SiluAndMul() @@ -173,6 +150,7 @@ def forward(self, x): class BailingMoE(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -187,10 +165,12 @@ def __init__( self.hidden_size = config.hidden_size self.num_shared_experts = config.num_shared_experts self.norm_expert_prob = config.norm_topk_prob + self.moe_intermediate_size = config.moe_intermediate_size self.gate = ReplicatedLinear( self.hidden_size, self.num_experts, bias=False, quant_config=None ) + self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob) self.experts = FusedMoE( @@ -198,17 +178,21 @@ def __init__( top_k=self.top_k, layer_id=layer_id, hidden_size=self.hidden_size, - intermediate_size=config.moe_intermediate_size, + intermediate_size=self.moe_intermediate_size, + reduce_results=False, quant_config=quant_config, prefix=add_prefix("experts", prefix), ) if self.num_shared_experts > 0: - intermediate_size = config.moe_intermediate_size * self.num_shared_experts + 
shared_intermediate_size = ( + self.moe_intermediate_size * self.num_shared_experts + ) self.shared_experts = BailingMLP( - intermediate_size=intermediate_size, + intermediate_size=shared_intermediate_size, config=config, quant_config=quant_config, + reduce_results=False, prefix=add_prefix("shared_experts", prefix), ) else: @@ -216,21 +200,22 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states_flat = hidden_states.view(-1, self.hidden_size) shared_output = None - if self.num_shared_experts > 0: - shared_output = self.shared_experts(hidden_states) + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states_flat) - router_logits, _ = self.gate(hidden_states) - topk_output = self.topk(hidden_states, router_logits) - final_hidden_states = self.experts(hidden_states, topk_output) + router_logits, _ = self.gate(hidden_states_flat) + topk_output = self.topk(hidden_states_flat, router_logits) + final_hidden_states = self.experts(hidden_states_flat, topk_output) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(orig_shape) @@ -239,18 +224,18 @@ class BailingMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - layer_id: int = 0, + layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size - self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention = BailingAttention( - config, layer_id, quant_config, prefix=f"{prefix}.attention" + config, layer_id, quant_config, prefix=add_prefix("attention", prefix) + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps ) - self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) self.mlp = BailingMoE( config=config, layer_id=layer_id, @@ -262,25 +247,31 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - forward_batch: ForwardBatch, residual: Optional[torch.Tensor], - ) -> torch.Tensor: + forward_batch: ForwardBatch, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Pre-normalization and residual connection for the attention block if residual is None: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + normed_hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + normed_hidden_states, residual = self.input_layernorm( + hidden_states, residual + ) - hidden_states = self.attention( - hidden_states=hidden_states, + attn_output = self.attention( + hidden_states=normed_hidden_states, position_ids=position_ids, forward_batch=forward_batch, ) - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - mlp_output = self.mlp(hidden_states) - hidden_states = residual + mlp_output - return hidden_states, residual + # Pre-normalization and residual connection for the MLP block + normed_hidden_states, residual = self.post_attention_layernorm( + attn_output, residual + ) + mlp_output = self.mlp(normed_hidden_states) + + return mlp_output, residual class BailingMoeModel(nn.Module): @@ -292,43 
+283,30 @@ def __init__( prefix: str = "", ): super().__init__() - self.config = config - self.quant_config = quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size - self.pp_group = parallel_state.get_pp_group() - - if self.pp_group.is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("embed_tokens", prefix), - ) - - else: - self.embed_tokens = PPMissingLayer() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), + ) self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) self.layers = make_layers( config.num_hidden_layers, lambda idx, prefix: BailingMoeBlock( config=config, + layer_id=idx, quant_config=quant_config, prefix=prefix, ), - prefix=f"{prefix}.layers", + prefix=add_prefix("layers", prefix), ) - if parallel_state.get_pp_group().is_last_rank: - self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) def forward( self, @@ -336,13 +314,12 @@ def forward( position_ids: torch.Tensor, forward_batch: ForwardBatch, input_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor]: - + ) -> torch.Tensor: if input_embeds is None: - - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_tokens(input_ids) else: hidden_states = input_embeds + residual = None for layer in self.layers: hidden_states, residual = layer( @@ -357,6 +334,7 @@ def forward( class BailingMoeForCausalLM(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -381,14 +359,14 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, inputs_embeds: Optional[torch.Tensor] = None, - ) -> LogitsProcessorOutput: - + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), @@ -403,36 +381,45 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if "model.word_embeddings" in name: - name = name.replace("model.word_embeddings", "model.embed_tokens") - is_loaded = False + if ( + hasattr(self.config, "norm_head") + and self.config.norm_head + and "lm_head.weight" in name + ): + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) + + if "model.word_embeddings.weight" == name: + name = "model.embed_tokens.weight" + for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name in name and "expert" not in name: - name = name.replace(weight_name, param_name) - param = params_dict[name] + if weight_name in name and "mlp.experts" not in name: + full_param_name = name.replace(weight_name, param_name) + param = params_dict[full_param_name] param.weight_loader(param, loaded_weight, shard_id) - is_loaded = True break - if is_loaded: - continue + else: + for p_name, w_name, e_id, s_id in expert_params_mapping: + if w_name in name and "mlp.experts" in name: + full_param_name = name.replace(w_name, 
p_name) + param = params_dict[full_param_name] + param.weight_loader( + param, + loaded_weight, + full_param_name, + shard_id=s_id, + expert_id=e_id, + ) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue - for p_name, w_name, e_id, s_id in expert_params_mapping: - if w_name in name: - name = name.replace(w_name, p_name) param = params_dict[name] - param.weight_loader( - param, loaded_weight, name, shard_id=s_id, expert_id=e_id + weight_loader = getattr( + param, "weight_loader", default_weight_loader ) - is_loaded = True - break - if is_loaded: - continue - - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + weight_loader(param, loaded_weight) EntryClass = BailingMoeForCausalLM diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index e56e8282c775..eb6763c67724 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -67,7 +67,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/phi-1_5", trust_remote_code=True), ModelCase("adept/persimmon-8b-chat"), - ModelCase("inclusionAI/Ling-lite"), + ModelCase("inclusionAI/Ling-lite", trust_remote_code=True), ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), From 40964740597e32724b346ceabade39334faa6fd2 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri <34855725+ppraneth@users.noreply.github.com> Date: Fri, 1 Aug 2025 21:14:54 +0530 Subject: [PATCH 4/4] Update docs/supported_models/generative_models.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/supported_models/generative_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 62a2ed14d52c..4f65c872ae2f 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -47,6 +47,6 @@ in the GitHub search bar. | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | | **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. | | **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. | -| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models: Ling-Lite (16.8B total, 2.75B active) and Ling-Plus (290B total, 28.8B active), built for flexibility and high performance across NLP and complex reasoning tasks. Open-source and scalable, fostering rapid innovation and wide adoption. | +| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. 
| | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
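A few notes for reviewers follow; none of them are part of the patches themselves.

First, the routing math that the final `BailingMoE.forward` delegates to `TopK` and `FusedMoE`, reproduced as a plain-PyTorch sketch: no tensor parallelism, a toy SwiGLU module standing in for `BailingMLP` and the fused Triton experts, and softmax scoring assumed for `TopK` (its usual default). The class and dimension names below are illustrative only.

```python
# Minimal, unparallelized sketch of the BailingMoE forward pass in this patch:
# softmax router -> top-k expert selection (optionally renormalized, matching
# config.norm_topk_prob) -> weighted expert sum -> add shared-expert output.
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """Stands in for BailingMLP: down_proj(SiLU(gate) * up)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        return self.down(F.silu(gate) * up)


class ToyBailingMoE(nn.Module):
    def __init__(self, hidden_size=64, moe_intermediate_size=128, num_experts=8,
                 top_k=2, num_shared_experts=1, norm_topk_prob=True):
        super().__init__()
        self.top_k = top_k
        self.norm_topk_prob = norm_topk_prob
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            SwiGLU(hidden_size, moe_intermediate_size) for _ in range(num_experts)
        )
        # Shared experts collapse into one wider MLP applied to every token,
        # matching how the patch sizes its shared_experts branch.
        self.shared_experts = (
            SwiGLU(hidden_size, moe_intermediate_size * num_shared_experts)
            if num_shared_experts > 0
            else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        x = hidden_states.view(-1, orig_shape[-1])
        shared = self.shared_experts(x) if self.shared_experts is not None else 0.0
        probs = torch.softmax(self.gate(x), dim=-1)    # router logits -> probs
        weights, idx = probs.topk(self.top_k, dim=-1)  # what TopK selects
        if self.norm_topk_prob:                        # renormalize=...
            weights = weights / weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):      # FusedMoE, unfused
            rows, slots = (idx == e).nonzero(as_tuple=True)
            if rows.numel():
                out[rows] += weights[rows, slots].unsqueeze(-1) * expert(x[rows])
        return (out + shared).view(orig_shape)


x = torch.randn(3, 5, 64)          # (batch, seq, hidden)
print(ToyBailingMoE()(x).shape)    # torch.Size([3, 5, 64])
```

With `norm_topk_prob=True`, the `top_k` weights kept for each token are renormalized to sum to one, which is what `TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)` provides in the final patch.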
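Second, the `stacked_params_mapping` in `load_weights` exists because the checkpoint stores `gate_proj` and `up_proj` as separate tensors while the model holds a single fused `gate_up_proj` (a `MergedColumnParallelLinear` built over `[intermediate_size] * 2`); the `"mlp.experts" not in name` guard keeps expert weights out of this path so `expert_params_mapping` can handle them instead. A parallelism-free illustration of what the two `shard_id` values mean (tensor names and sizes here are invented for the example):

```python
# What stacked_params_mapping achieves, stripped of tensor parallelism: each
# checkpoint tensor lands in a disjoint row-slice of the fused parameter.
import torch

hidden, inter = 4, 6
ckpt = {
    "mlp.gate_proj.weight": torch.randn(inter, hidden),  # shard_id 0
    "mlp.up_proj.weight": torch.randn(inter, hidden),    # shard_id 1
}
fused = torch.empty(2 * inter, hidden)  # mlp.gate_up_proj.weight in the model
for shard_id, key in enumerate(["mlp.gate_proj.weight", "mlp.up_proj.weight"]):
    fused[shard_id * inter : (shard_id + 1) * inter] = ckpt[key]

assert torch.equal(fused[:inter], ckpt["mlp.gate_proj.weight"])
assert torch.equal(fused[inter:], ckpt["mlp.up_proj.weight"])
```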
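Finally, a quick way to exercise the new `BailingMoeForCausalLM` entry end to end once the patches are applied. This is a sketch, not part of the change: it assumes `sglang.Engine` accepts `model_path` and `trust_remote_code` keyword arguments (mirroring the `trust_remote_code=True` flag the updated `ModelCase` uses) and that the machine can download the `inclusionAI/Ling-lite` checkpoint.

```python
# Local smoke test for the new model entry; mirrors the updated ModelCase in
# test/srt/models/test_generation_models.py.
import sglang as sgl

llm = sgl.Engine(model_path="inclusionAI/Ling-lite", trust_remote_code=True)
out = llm.generate(
    "The capital of France is",
    {"temperature": 0, "max_new_tokens": 8},
)
print(out)
llm.shutdown()
```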