Merged
35 changes: 21 additions & 14 deletions vllm/model_executor/models/llama.py
@@ -162,20 +162,9 @@ def __init__(
prefix=f"{prefix}.o_proj",
)

is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
self._init_rotary_emb(config,
rope_scaling=rope_scaling,
quant_config=quant_config)

if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ def forward(
output, _ = self.o_proj(attn_output)
return output

def _init_rotary_emb(self, config: LlamaConfig,
rope_scaling: Optional[dict[str, Any]],
quant_config: Optional[QuantizationConfig]) -> None:
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
is_neox_style = False

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)


class LlamaDecoderLayer(nn.Module):

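As an aside, the llama.py change above turns the previously inline RoPE setup into an overridable `_init_rotary_emb` hook. A minimal sketch of how a subclass could use it, assuming a hypothetical `InterleavedRopeLlamaAttention` class that simply pins the GPT-J (interleaved) layout (neither the class nor the pinned style is part of this PR):

```python
from typing import Any, Optional

from transformers import LlamaConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.llama import LlamaAttention


class InterleavedRopeLlamaAttention(LlamaAttention):
    """Hypothetical subclass: always use the GPT-J (interleaved) RoPE layout."""

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        # Same get_rope call as the base class, but with is_neox_style
        # pinned to False instead of being derived from the quant config.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            partial_rotary_factor=self.partial_rotary_factor,
        )
```

This is exactly the pattern the nemotron_nas.py change below relies on with its new `DeciLMAttention` subclass.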
48 changes: 46 additions & 2 deletions vllm/model_executor/models/nemotron_nas.py
@@ -23,18 +23,20 @@
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
return n + k - (n % k)


class DeciLMAttention(LlamaAttention):

def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__(config, hidden_size, num_heads, num_kv_heads,
rope_theta, rope_scaling, max_position_embeddings,
quant_config, bias, bias_o_proj, cache_config, prefix,
attn_type)

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
quant_config: Optional[QuantizationConfig]) -> None:
# Enables YARN for Mistral and LLaMA4 derivatives.
is_neox_style = True
if hasattr(config, "position_embedding_type"):
is_neox_style = config.position_embedding_type not in [
"mistral_yarn", "rope_llama4"
]

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor)


class DeciLMDecoderLayer(nn.Module):

def __init__(
@@ -98,7 +142,7 @@ def __init__(
if not self._is_no_op_attention:
num_kv_heads = (config.num_attention_heads //
block_config.attention.n_heads_in_group)
self.self_attn = LlamaAttention(
self.self_attn = DeciLMAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
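For reference, the new branch in `DeciLMAttention._init_rotary_emb` reduces to a small predicate on the HF config. A standalone sketch of that decision logic, where the `_is_neox_style` helper and the `SimpleNamespace` stand-ins are illustrative rather than part of the PR:

```python
from types import SimpleNamespace


def _is_neox_style(config) -> bool:
    # Mirrors DeciLMAttention._init_rotary_emb: configs that declare a
    # YARN-style position embedding opt out of the NeoX rotary layout.
    if hasattr(config, "position_embedding_type"):
        return config.position_embedding_type not in [
            "mistral_yarn", "rope_llama4"
        ]
    return True


# Plain LLaMA-style configs keep the default NeoX layout ...
assert _is_neox_style(SimpleNamespace()) is True
# ... while Mistral-YARN / LLaMA4-rope derivatives switch it off.
assert _is_neox_style(
    SimpleNamespace(position_embedding_type="mistral_yarn")) is False
```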