diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c15c0213b520..6584980f6dc2 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -162,20 +162,9 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )
 
-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
 
         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ def forward(
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
 
 class LlamaDecoderLayer(nn.Module):
 
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index f4d5a77f2086..9808fe05558e 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -23,18 +23,20 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):
 
     def __init__(
@@ -98,7 +142,7 @@ def __init__(
         if not self._is_no_op_attention:
             num_kv_heads = (config.num_attention_heads //
                             block_config.attention.n_heads_in_group)
-            self.self_attn = LlamaAttention(
+            self.self_attn = DeciLMAttention(
                 config=config,
                 hidden_size=self.hidden_size,
                 num_heads=config.num_attention_heads,