Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix new Llama3.1 GGUF model loading #7269

Merged
merged 2 commits into from
Aug 8, 2024

Conversation

Isotr0py
Copy link
Collaborator

@Isotr0py Isotr0py commented Aug 7, 2024

FILL IN THE PR DESCRIPTION HERE

FIX #7268 (link existing issues this PR will resolve)

  • After the investigation about llama.cpp PR-8676, which introduced the rope_freqs.weight, seems that rope_freqs.weight is not an exact model weight but a tensor with rope factor for llama.cpp inference improvement.
  • As a result, I decided to ignore these tensors like rope_freqs.weight not required by exact model state dict.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

github-actions bot commented Aug 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@Isotr0py
Copy link
Collaborator Author

Isotr0py commented Aug 7, 2024

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 7, 2024
@mgoin
Copy link
Collaborator

mgoin commented Aug 7, 2024

Could you run an eval on a Llama 3.1 GGUF checkpoint to make sure that it gives a reasonable result?

For example this is what is it like to run a gsm8k eval on Llama 3 GGUF:

pip install lm-eval
bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Meta-Llama-3-8B-Instruct.Q4_K_M.gguf -b 32 -l 1000 -f 5 -t 1

INFO 08-07 15:58:30 config.py:1451] Downcasting torch.float32 to torch.float16.
WARNING 08-07 15:58:30 config.py:254] gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 08-07 15:58:33 llm_engine.py:174] Initializing an LLM engine (v0.5.3.post1) with config: model='Meta-Llama-3-8B-Instruct.Q4_K_M.gguf', speculative_config=None, tokenizer='Meta-Llama-3-8B-Instruct.Q4_K_M.gguf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.GGUF, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gguf, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=1234, served_model_name=Meta-Llama-3-8B-Instruct.Q4_K_M.gguf, use_v2_block_manager=False, enable_prefix_caching=False)
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
INFO 08-07 15:58:58 ray_gpu_executor.py:117] use_ray_spmd_worker: False
INFO 08-07 15:58:58 ray_gpu_executor.py:120] driver_ip: 216.81.245.69
INFO 08-07 15:59:04 model_runner.py:720] Starting to load model Meta-Llama-3-8B-Instruct.Q4_K_M.gguf...
INFO 08-07 15:59:19 model_runner.py:732] Loading model weights took 4.7085 GB
INFO 08-07 15:59:22 distributed_gpu_executor.py:56] # GPU blocks: 33275, # CPU blocks: 2048
INFO 08-07 15:59:23 model_runner.py:1024] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-07 15:59:23 model_runner.py:1028] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 08-07 15:59:40 model_runner.py:1225] Graph capturing finished in 16 secs.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 363.41it/s]
Running generate_until requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [15:35<00:00,  1.07it/s]
vllm (pretrained=Meta-Llama-3-8B-Instruct.Q4_K_M.gguf,tensor_parallel_size=1,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096), gen_kwargs: (None), limit: 1000.0, num_fewshot: 5, batch_size: 32
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.613|±  |0.0154|
|     |       |strict-match    |     5|exact_match|↑  |0.732|±  |0.0140|

@Isotr0py
Copy link
Collaborator Author

Isotr0py commented Aug 8, 2024

OK, I will run the evaluation on an idle GPU. This may cost some time because flash-attention is unavailable on that GPU.

@RodriMora
Copy link

I just run it with an external lm eval tool as using the .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh was extremely slow (>100s/it) vs 4s/it running it this way:

lm eval:
lm_eval --model local-completions --tasks gsm8k --model_args model=/home/ubuntuai/models/Meta-Llama-3.1-8B-Instruct,base_url=http://localhost:5001/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False

vLLM:
CUDA_VISIBLE_DEVICES=1 vllm serve --host 0.0.0.0 --port 5001 --gpu-memory-utilization 0.9 ~/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf --served-model-name /home/ubuntuai/models/Meta-Llama-3.1-8B-Instruct -tp 1 --max-model-len 8192

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7506|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.7081|±  |0.0125|

@mgoin
Copy link
Collaborator

mgoin commented Aug 8, 2024

@RodriMora maybe this was some issue with batching perf.. either way, accuracy looks good so approving, thank you!

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@simon-mo simon-mo merged commit 8334c39 into vllm-project:main Aug 8, 2024
47 of 49 checks passed
@Isotr0py Isotr0py deleted the gguf-llama3.1 branch August 9, 2024 02:30
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: model loading failed for some Llama3.1 GGUF model
5 participants