diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index c5358a6ece24..3d75fa3077d3 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -39,7 +39,7 @@ in the GitHub search bar. | **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. | | **MiniMax-M2** (M2, M2.1) | `minimax/MiniMax-M2`, `minimax/MiniMax-M2.1` | MiniMax’s SOTA LLM for coding & agentic workflows. | | **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. | -| **Command-R** (Cohere) | `CohereForAI/c4ai-command-r-v01` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. | +| **Command-(R,A)** (Cohere) | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025` | Cohere’s open conversational LLMs (Command R and Command A series) optimized for long context, retrieval-augmented generation, and tool use. | | **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. | | **Grok** (xAI) | `xai-org/grok-1` | xAI’s grok-1 model known for vast size(314B parameters) and high quality; integrated in SGLang for high-performance inference. | | **ChatGLM** (GLM-130B family) | `THUDM/chatglm2-6b` | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. | @@ -64,4 +64,4 @@ in the GitHub search bar. 
| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. | | **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). | | **Jet-Nemotron** | `jet-ai/Jet-Nemotron-2B` | Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models, while achieving significant efficiency gains. | -| **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's foundational MoE Trinity family of models, open weights under Apache 2.0. | +| **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's foundational MoE Trinity family of models, open weights under Apache 2.0. 
| diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index ebbf8ed64029..7c799f5f8400 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -43,7 +43,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn.parameter import Parameter -from transformers import PretrainedConfig +from transformers import Cohere2Config, CohereConfig, PretrainedConfig from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -198,12 +198,23 @@ def __init__( rope_scaling=self.rope_scaling, is_neox_style=False, ) + + self.v1 = isinstance(config, CohereConfig) + self.v2 = isinstance(config, Cohere2Config) + + # Model v2 has interleaved sliding windows, v1 does not + if self.v2 and config.layer_types[layer_id] == "sliding_attention": + self.sliding_window_size = config.sliding_window + else: + self.sliding_window_size = -1 + self.attn = RadixAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + sliding_window_size=self.sliding_window_size, quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -235,7 +246,9 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_qk_norm: q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) + # Model v1 uses RoPE throughout, Model v2 uses RoPE only for SWA layers + if self.v1 or self.sliding_window_size > 0: + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -348,7 +361,8 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.logits_processor = LogitsProcessor(config) + self.logit_scale = getattr(config, "logit_scale", None) + self.logits_processor = LogitsProcessor(config, logit_scale=self.logit_scale) self.model = CohereModel( config, quant_config, prefix=add_prefix("model", prefix) )