From 8bae38f94e5a847b7c4ea3d17a9081255b1d3630 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 14 May 2024 20:28:39 +0000 Subject: [PATCH 1/5] Dynamic RoPE scaling --- vllm/config.py | 7 ++++++- vllm/engine/arg_utils.py | 18 +++++++++++++----- vllm/engine/llm_engine.py | 10 ++++++---- vllm/transformers_utils/config.py | 10 +++++++++- 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 44ed5635f9a3..4208152ec4d5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -45,6 +45,9 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -84,6 +87,7 @@ def __init__( seed: int, revision: Optional[str] = None, code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, @@ -102,6 +106,7 @@ def __init__( self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling self.tokenizer_revision = tokenizer_revision self.quantization = quantization self.quantization_param_path = quantization_param_path @@ -116,7 +121,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision) + code_revision, self.rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1ba424c4eeb1..0a9ec7472fbc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,5 +1,6 @@ import argparse import dataclasses +import json from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -49,6 +50,7 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False @@ -330,6 +332,11 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"type":"dynamic","factor":2.0}') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' @@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig: model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.dtype, self.seed, self.revision, - self.code_revision, self.tokenizer_revision, self.max_model_len, - self.quantization, self.quantization_param_path, - self.enforce_eager, self.max_context_len_to_capture, - self.max_seq_len_to_capture, self.max_logprobs, - self.skip_tokenizer_init, self.served_model_name) + self.code_revision, self.rope_scaling, self.tokenizer_revision, + self.max_model_len, self.quantization, + self.quantization_param_path, self.enforce_eager, + self.max_context_len_to_capture, self.max_seq_len_to_capture, + self.max_logprobs, self.skip_tokenizer_init, + self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f6a5284093c1..60e23d4df15b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -104,10 +104,11 @@ def __init__( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " - "max_seq_len=%d, download_dir=%r, load_format=%s, " - "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " - "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " + "rope_scaling=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, @@ -117,6 +118,7 @@ def __init__( model_config.skip_tokenizer_init, model_config.tokenizer_mode, model_config.revision, + model_config.rope_scaling, model_config.tokenizer_revision, model_config.trust_remote_code, model_config.dtype, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1756c91a612f..f36d84dbdf7f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,9 +2,12 @@ from transformers import AutoConfig, PretrainedConfig +from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, JAISConfig, MPTConfig, RWConfig) +logger = init_logger(__name__) + _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -18,7 +21,8 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, - code_revision: Optional[str] = None) -> PretrainedConfig: + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( model, @@ -41,6 +45,10 @@ def get_config(model: str, config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + if rope_scaling is not None: + logger.info("Updating rope_scaling from %r to %r", + getattr(config, "rope_scaling", None), rope_scaling) + config.update({"rope_scaling": rope_scaling}) return config From 1635321f8723da3f218a10e9d55b7260bb13b193 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 21 May 2024 16:17:32 +0000 Subject: [PATCH 2/5] Add rope scaling test --- tests/test_config.py | 53 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 19db10630bba..50518bf00a75 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -36,4 +36,55 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() is None mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW \ No newline at end of file + assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW + + +def test_rope_scaling(): + TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert llama_model_config.rope_scaling is None + assert llama_model_config.max_model_len == 8192 + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert llama_model_config.rope_scaling == TEST_ROPE_SCALING + assert llama_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert longchat_model_config.rope_scaling == LONGCHAT_ROPE_SCALING + assert longchat_model_config.max_model_len == 131072 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert longchat_model_config.rope_scaling == TEST_ROPE_SCALING + assert longchat_model_config.max_model_len == 32768 From 7a43a7775761f14740d613e878e0c8d97bb6e4a5 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 21 May 2024 17:18:37 +0000 Subject: [PATCH 3/5] Fix test --- tests/test_config.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 50518bf00a75..16585391c825 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -51,7 +51,7 @@ def test_rope_scaling(): dtype="float16", seed=0, ) - assert llama_model_config.rope_scaling is None + assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None assert llama_model_config.max_model_len == 8192 llama_model_config = ModelConfig( @@ -63,7 +63,8 @@ def test_rope_scaling(): seed=0, rope_scaling=TEST_ROPE_SCALING, ) - assert llama_model_config.rope_scaling == TEST_ROPE_SCALING + assert getattr(llama_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING assert llama_model_config.max_model_len == 16384 longchat_model_config = ModelConfig( @@ -74,7 +75,8 @@ def test_rope_scaling(): dtype="float16", seed=0, ) - assert longchat_model_config.rope_scaling == LONGCHAT_ROPE_SCALING + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == LONGCHAT_ROPE_SCALING assert longchat_model_config.max_model_len == 131072 longchat_model_config = ModelConfig( @@ -86,5 +88,6 @@ def test_rope_scaling(): seed=0, rope_scaling=TEST_ROPE_SCALING, ) - assert longchat_model_config.rope_scaling == TEST_ROPE_SCALING + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING assert longchat_model_config.max_model_len == 32768 From bb81a2e6f2120852b1d7148de50c21a3674a275a Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 21 May 2024 18:04:05 +0000 Subject: [PATCH 4/5] Fix test --- tests/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 16585391c825..6bc51a53dc07 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -77,7 +77,7 @@ def test_rope_scaling(): ) assert getattr(longchat_model_config.hf_config, "rope_scaling", None) == LONGCHAT_ROPE_SCALING - assert longchat_model_config.max_model_len == 131072 + assert longchat_model_config.max_model_len == 16384 longchat_model_config = ModelConfig( "lmsys/longchat-13b-16k", @@ -90,4 +90,4 @@ def test_rope_scaling(): ) assert getattr(longchat_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING - assert longchat_model_config.max_model_len == 32768 + assert longchat_model_config.max_model_len == 4096 From 22c76cf1ce771728a0babc4b59930d922d7f0b1b Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 21 May 2024 19:46:06 +0000 Subject: [PATCH 5/5] Minor fix --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 4208152ec4d5..3256c1196791 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -121,7 +121,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, self.rope_scaling) + code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config,