Commit 79c92c7

[Model] Add Gemma 2 (#5908)
1 parent 736ed38 · commit 79c92c7

File tree: 9 files changed, +499 -9 lines changed

docs/source/models/supported_models.rst (+4)

@@ -55,6 +55,10 @@ Alongside each architecture, we include some popular models that use it.
     - Gemma
     - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
     - ✅︎
+  * - :code:`Gemma2ForCausalLM`
+    - Gemma2
+    - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
+    - ✅︎
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
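With the architecture listed as supported, Gemma 2 can be loaded through vLLM's existing offline API. A minimal sketch, assuming this commit plus transformers >= 4.42.0 are installed and the (gated) google/gemma-2-9b weights are accessible:

from vllm import LLM, SamplingParams

# Load Gemma 2 exactly like any other registered architecture.
llm = LLM(model="google/gemma-2-9b")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)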

requirements-common.txt (+1 -1)

@@ -6,7 +6,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
+transformers >= 4.42.0 # Required for Gemma 2.
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
 aiohttp

vllm/config.py (+23 -7)

@@ -14,7 +14,7 @@
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
                         update_environment_variables)
 
 if TYPE_CHECKING:

@@ -141,6 +141,17 @@ def __init__(
             code_revision, rope_scaling, rope_theta)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+
+        if (not self.disable_sliding_window
+                and self.hf_text_config.model_type == "gemma2"
+                and self.hf_text_config.sliding_window is not None):
+            print_warning_once(
+                "Gemma 2 uses sliding window attention for every odd layer, "
+                "which is currently not supported by vLLM. Disabling sliding "
+                "window and capping the max length to the sliding window size "
+                f"({self.hf_text_config.sliding_window}).")
+            self.disable_sliding_window = True
+
         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
             max_model_len=max_model_len,

@@ -257,8 +268,7 @@ def verify_with_parallel_config(
                 "BitAndBytes quantization with TP or PP is not supported yet.")
 
     def get_hf_config_sliding_window(self) -> Optional[int]:
-        """Get the sliding window size, or None if disabled.
-        """
+        """Get the sliding window size, or None if disabled."""
 
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
         # addition to sliding window size. We check if that field is present

@@ -1256,10 +1266,16 @@ def _get_and_verify_dtype(
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
-                # Following the common practice, we use float16 for float32
-                # models.
-                logger.info("Casting torch.float32 to torch.float16.")
-                torch_dtype = torch.float16
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16.")
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
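Net effect of the two config changes above: with dtype="auto", a float32 Gemma 2 checkpoint is downcast to bfloat16 rather than float16, and sliding-window attention is disabled with the max length capped to the window size. The dtype default can still be overridden, and max_model_len can be set explicitly to stay within the capped length. A hedged sketch (the max_model_len value is illustrative, not taken from this diff):

from vllm import LLM

llm = LLM(
    model="google/gemma-2-9b",
    dtype="float16",      # opt out of the bfloat16 default described in the log message
    max_model_len=4096,   # illustrative cap; keep within the model's sliding window size
)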

vllm/lora/layers.py (+4)

@@ -1069,6 +1069,10 @@ def vocab_size(self):
     def scale(self):
         return self.base_layer.scale
 
+    @property
+    def soft_cap(self):
+        return self.base_layer.soft_cap
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size

vllm/model_executor/layers/layernorm.py (+46)

@@ -95,3 +95,49 @@ def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
         return s
+
+
+class GemmaRMSNorm(CustomOp):
+    """RMS normalization for Gemma.
+
+    Two differences from the above RMSNorm:
+        1. x * (1 + w) instead of x * w.
+        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        return self.forward_native(x, residual)
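The docstring calls out two deviations from the plain RMSNorm above: the weight is applied as (1 + w), and the cast back to the original dtype happens after the weight multiply rather than before. A standalone sketch using only plain PyTorch (the function names are illustrative, not part of this diff):

import torch

def llama_style_rmsnorm(x, w, eps=1e-6):
    # Normalize in float32, cast back first, then scale: x.to(orig_dtype) * w
    orig_dtype = x.dtype
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return xf.to(orig_dtype) * w

def gemma_style_rmsnorm(x, w, eps=1e-6):
    # Scale by (1 + w) while still in float32, then cast back.
    orig_dtype = x.dtype
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (xf * (1.0 + w.float())).to(orig_dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.full((8,), 0.5)  # Gemma checkpoints store the scale as an offset from 1
a = llama_style_rmsnorm(x, 1.0 + w).float()
b = gemma_style_rmsnorm(x, w).float()
print(torch.allclose(a, b, atol=1e-2))  # numerically close; the cast order can differ in the last bits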

vllm/model_executor/layers/logits_processor.py (+9 -1)

@@ -22,7 +22,8 @@ def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: float = 1.0,
-                 logits_as_input: bool = False) -> None:
+                 logits_as_input: bool = False,
+                 soft_cap: Optional[float] = None) -> None:
        """
        Args:
            scale: A scaling factor to apply to the logits.

@@ -34,6 +35,8 @@ def __init__(self,
         self.logits_as_input = logits_as_input
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
+        # Soft cap the logits. Used in Gemma 2.
+        self.soft_cap = soft_cap
 
     def forward(
         self,

@@ -52,6 +55,11 @@ def forward(
         logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
         if logits is not None:
+            if self.soft_cap is not None:
+                logits = logits / self.soft_cap
+                logits = torch.tanh(logits)
+                logits = logits * self.soft_cap
+
             if self.scale != 1.0:
                 logits *= self.scale
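The three added lines implement a scaled tanh: logits are divided by soft_cap, squashed, and rescaled, so the final logits stay strictly inside (-soft_cap, soft_cap) while small values pass through almost unchanged. A standalone sketch (the value 30.0 mirrors Gemma 2's final_logit_softcapping config field, stated here as an assumption rather than read from this diff):

import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Equivalent to the in-place sequence above: capped = c * tanh(logits / c)
    return soft_cap * torch.tanh(logits / soft_cap)

logits = torch.tensor([0.5, 10.0, 100.0, -500.0])
print(soft_cap_logits(logits, soft_cap=30.0))
# Small logits are nearly unchanged; large ones saturate near +/-30.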

vllm/model_executor/layers/rotary_embedding.py (+10)

@@ -610,6 +610,16 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+class GemmaRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
+        inv_freq = 1.0 / (base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
+            self.rotary_dim))
+        return inv_freq
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
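The only override here is _compute_inv_freq, which builds the exponents from an int64 arange and converts to float before the division, presumably to match the numerics of the referenced HF Gemma code. A standalone sketch with illustrative values (rotary_dim=256 and base=10000 are not taken from this diff):

import torch

def gemma_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Exponents 0, 2, 4, ... built as int64, then cast to float32 before dividing.
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).float()
    return 1.0 / (base ** (exponents / rotary_dim))

print(gemma_inv_freq(256, 10000.0)[:4])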

vllm/model_executor/models/__init__.py (+1)

@@ -23,6 +23,7 @@
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
+    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
