[Model] add rope_scaling support for qwen2 #4930

Merged
simon-mo merged 1 commit into vllm-project:main from hzhwcmhf:main
May 21, 2024

Conversation

@hzhwcmhf
Contributor

We are the Qwen team and would like to add rope_scaling support for Qwen2. This enables YaRN to extend the context length.

FIX #4824
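
For context, YaRN keeps the high-frequency rotary dimensions untouched (extrapolation) and linearly interpolates the low-frequency ones, blending the two with a ramp, plus a small attention-logit scale. A minimal sketch of the frequency computation follows (illustrative only, not the exact vLLM implementation; the defaults beta_fast=32 and beta_slow=1 follow the YaRN paper):

import math
import torch

# Illustrative YaRN-style rotary frequency scaling (not vLLM's exact code).
# d: rotary dim, theta: rope_theta, factor: rope_scaling["factor"],
# orig_max: rope_scaling["original_max_position_embeddings"].
def yarn_inv_freq(d=128, theta=1e6, factor=4.0, orig_max=32768,
                  beta_fast=32, beta_slow=1):
    pos_freqs = theta ** (torch.arange(0, d, 2).float() / d)
    extrapolation = 1.0 / pos_freqs             # unscaled RoPE frequencies
    interpolation = 1.0 / (factor * pos_freqs)  # position-interpolated frequencies

    # Dimension index at which a dimension completes `num_rotations`
    # rotations over the original context window.
    def correction_dim(num_rotations):
        return (d * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(theta))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), d // 2 - 1)

    # Ramp is 0 for high-frequency dims (keep extrapolation) and 1 for
    # low-frequency dims (full interpolation), linear in between.
    ramp = ((torch.arange(d // 2).float() - low) / max(high - low, 1)).clamp(0, 1)
    return interpolation * ramp + extrapolation * (1 - ramp)

# YaRN also multiplies attention logits by a small "mscale" factor:
mscale = 0.1 * math.log(4.0) + 1.0  # ~1.14 for factor=4.0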

Test

  • Download https://huggingface.co/Qwen/Qwen1.5-7B-Chat and replace its config.json with the following (a quick sanity check of this config appears after the test output below):
{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 131072,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936,
  "rope_scaling": {
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
    "type": "yarn"
  }
}
  • Run the following code:
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# tensor_parallel_size=8 shards the model across 8 GPUs;
# enforce_eager=True skips CUDA graph capture.
llm = LLM(model="/path/to/model", tensor_parallel_size=8, enforce_eager=True)

prompt = "There is an important piece of info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \
    "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 5000 + \
    "\nWhat is the pass key?"  # about 100k tokens

prompts = [
    f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n",
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

The output is:

Generated text: 'The pass key is 28884.'

If we do not use the YaRN method, the output is an empty string.
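
Note the arithmetic: factor 4.0 × original_max_position_embeddings 32768 gives a 131072-position window, which is why the ~100k-token prompt fits. A quick sanity check that the config is picked up and that the prompt really exceeds the original window (a sketch: /path/to/model is a placeholder and prompt is the variable from the script above):

from transformers import AutoConfig, AutoTokenizer

cfg = AutoConfig.from_pretrained("/path/to/model")
print(cfg.rope_scaling)
# {'factor': 4.0, 'original_max_position_embeddings': 32768, 'type': 'yarn'}
print(int(cfg.rope_scaling["factor"] * cfg.rope_scaling["original_max_position_embeddings"]))
# 131072, the extended window

tokenizer = AutoTokenizer.from_pretrained("/path/to/model")
print(len(tokenizer(prompt).input_ids))  # roughly 100k tokens, well beyond the original 32768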

@simon-mo simon-mo enabled auto-merge (squash) May 21, 2024 03:46
@simon-mo simon-mo merged commit d130b57 into vllm-project:main May 21, 2024
