14 | 14 | from vllm.tracing import is_otel_installed
15 | 15 | from vllm.transformers_utils.config import get_config, get_hf_text_config
16 | 16 | from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
17 |    | -                        is_hip, is_neuron, is_tpu, is_xpu,
   | 17 | +                        is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
18 | 18 |                         update_environment_variables)
19 | 19 |
20 | 20 | if TYPE_CHECKING:
@@ -141,6 +141,17 @@ def __init__(
141 | 141 |                                     code_revision, rope_scaling, rope_theta)
142 | 142 |         self.hf_text_config = get_hf_text_config(self.hf_config)
143 | 143 |         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
    | 144 | +
    | 145 | +        if (not self.disable_sliding_window
    | 146 | +                and self.hf_text_config.model_type == "gemma2"
    | 147 | +                and self.hf_text_config.sliding_window is not None):
    | 148 | +            print_warning_once(
    | 149 | +                "Gemma 2 uses sliding window attention for every odd layer, "
    | 150 | +                "which is currently not supported by vLLM. Disabling sliding "
    | 151 | +                "window and capping the max length to the sliding window size "
    | 152 | +                f"({self.hf_text_config.sliding_window}).")
    | 153 | +            self.disable_sliding_window = True
    | 154 | +
144 | 155 |         self.max_model_len = _get_and_verify_max_len(
145 | 156 |             hf_config=self.hf_text_config,
146 | 157 |             max_model_len=max_model_len,
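
Not part of the diff: a usage-level sketch of what this new check does. The checkpoint name is an assumption for illustration; with the change, loading a Gemma 2 model warns once, disables sliding window, and caps the effective context length to the sliding window size reported by the HF config (4096 for the released Gemma 2 checkpoints).

from vllm import LLM

# Hypothetical checkpoint name, used only for illustration.
llm = LLM(model="google/gemma-2-9b")

# After the check above runs, sliding window is disabled and the model's
# max length is capped to hf_config.sliding_window.
print(llm.llm_engine.model_config.max_model_len)
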
@@ -257,8 +268,7 @@ def verify_with_parallel_config(
257 | 268 |                 "BitAndBytes quantization with TP or PP is not supported yet.")
258 | 269 |
259 | 270 |     def get_hf_config_sliding_window(self) -> Optional[int]:
260 |     | -        """Get the sliding window size, or None if disabled.
261 |     | -        """
    | 271 | +        """Get the sliding window size, or None if disabled."""
262 | 272 |
263 | 273 |         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
264 | 274 |         # addition to sliding window size. We check if that field is present
@@ -1256,10 +1266,16 @@ def _get_and_verify_dtype(
1256 | 1266 |         dtype = dtype.lower()
1257 | 1267 |         if dtype == "auto":
1258 | 1268 |             if config_dtype == torch.float32:
1259 |      | -                # Following the common practice, we use float16 for float32
1260 |      | -                # models.
1261 |      | -                logger.info("Casting torch.float32 to torch.float16.")
1262 |      | -                torch_dtype = torch.float16
     | 1269 | +                if config.model_type == "gemma2":
     | 1270 | +                    logger.info(
     | 1271 | +                        "For Gemma 2, we downcast float32 to bfloat16 instead "
     | 1272 | +                        "of float16 by default. Please specify `dtype` if you "
     | 1273 | +                        "want to use float16.")
     | 1274 | +                    torch_dtype = torch.bfloat16
     | 1275 | +                else:
     | 1276 | +                    # Following the common practice, we use float16 for float32
     | 1277 | +                    # models.
     | 1278 | +                    torch_dtype = torch.float16
1263 | 1279 |             else:
1264 | 1280 |                 torch_dtype = config_dtype
1265 | 1281 |         else:
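
Also not part of the diff: a sketch of how the new dtype default interacts with a user-specified dtype; the checkpoint name is again an assumption. With dtype="auto", a float32 Gemma 2 config now resolves to bfloat16, while an explicit dtype="float16" still forces float16.

from vllm import LLM, SamplingParams

# dtype="auto" on a float32 Gemma 2 config now resolves to torch.bfloat16
# (see the _get_and_verify_dtype branch above).
llm = LLM(model="google/gemma-2-9b", dtype="auto")

# An explicit dtype still overrides the new default:
# llm = LLM(model="google/gemma-2-9b", dtype="float16")

params = SamplingParams(temperature=0.0, max_tokens=16)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)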