forked from vllm-project/vllm
[Draft][Gaudi][Model] Qwen2.5-VL optimization #1109
Closed
Commits (40)
- 6358192 (imangohari1): Revert "Enabled and optimized GLM-4v-9b on Gaudi (#691)"
- e3d0fb2 (imangohari1): Merge branch 'HabanaAI:habana_main' into ig/habana_main_85c985e_witho…
- a1097dd (imangohari1): fea(): Qwen2.5-vl upgrades. initial commit
- 4cdc7d7 (imangohari1): fea(): Added the changes needed from hpu-extension #61
- 254ca6b (imangohari1): reverted the hup_model_runner to habana_main and added the qwen2.5-vl…
- e8d4c3e (malkomes): using max_pixels instead of h,w
- 86e65fb (ssarkar2): clean up if/else
- 4037e04 (ssarkar2): clean up if-else 2
- de4e2c9 (ssarkar2): Fix cu_seqlens_now
- 35df595 (ssarkar2): Remove pdb, fix shape
- 37311b1 (ssarkar2): Remove breakpoints
- 10dfab9 (malkomes): using max_pixels during warmup
- 2042673 (ssarkar2): Video inputs ignored for now
- cfb7809 (ssarkar2): Remove unused return_time
- 3cce3bb (ssarkar2): Add warning about 112 alignment
- 1c4d44c (ssarkar2): Move VisionBuckets out to hpu model runner
- 5bbfdeb (jiminha): Create full attention mask outside of VisionTransformer
- 151b3e3 (malkomes): warmup multimoda graph with memory track?
- cd4a9de (jiminha): Enable profile_run and set disable_tensor_cache=True
- 260e724 (malkomes): we dont need this anymore
- 769cf6f (ssarkar2): we dont need b dim
- 79f65e0 (jiminha): Fix use_graph to return correctly for multimodal buckets
- cf49203 (malkomes): sort vision buckets
- bbe1571 (malkomes): set input_positions in text to be (3, seq_len) for mrope models
- 73fadb4 (malkomes): linting
- 8aff501 (malkomes): always compute embeddings for qwen2.5vl, even text
- b2c020e (malkomes): simplify dummy_multi_modal
- 372e793 (jiminha): Add VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO
- a03181d (imangohari1): Merge branch 'habana_main' into ig/qwen2_5-vl_visionTransformer
- 458b9fa (ssarkar2): Clean up some vars
- 36213b9 (jiminha): Remove SPLIT flag for Qwen
- 69e3111 (malkomes): Using Qwen2_5_VisionTransformerStaticShape
- bdd279a (malkomes): no need to change this
- 2137634 (malkomes): Update vllm/model_executor/models/qwen2_5_vl.py
- 14908e1 (malkomes): Update vllm/worker/hpu_model_runner.py
- c0d0207 (malkomes): working on comments
- 96189f9 (malkomes): buckets needs to be multiples of 8
- 7c9cf4c (malkomes): ops
- 32d7855 (jiminha): Fix the multimodal warmup memory calculation
- 9a0ff19 (jiminha): Fixe print error for einsum
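Several of these commits refer to the 112-pixel alignment, the vision buckets, and the multiples-of-8 constraint. Below is a minimal sketch of that alignment rule, inferred from the tests added later in this PR rather than taken from the PR's own code (the function name `align_to_112` is made up for illustration): each image side is snapped to the nearest multiple of 112 with a floor of 112, and since 112 = 8 * 14, the resulting count of 14x14 patches per image is always a multiple of 64.

```python
def align_to_112(width: int, height: int, factor: int = 112) -> tuple[int, int]:
    """Round each side to the nearest multiple of `factor`, never below `factor`."""
    w_bar = max(factor, round(width / factor) * factor)
    h_bar = max(factor, round(height / factor) * factor)
    return w_bar, h_bar


if __name__ == "__main__":
    for width, height in [(112, 112), (256, 221), (32, 32)]:
        w_bar, h_bar = align_to_112(width, height)
        patches = (w_bar // 14) * (h_bar // 14)  # 14x14 ViT patches per image
        tokens = patches // 4  # 2x2 spatial merge: 4 patches per placeholder token
        print((width, height), "->", (w_bar, h_bar), "patches:", patches, "tokens:", tokens)
```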
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The transformers requirement is repointed from the upstream Hugging Face pin to a fork:

```diff
@@ -1 +1 @@
-transformers @ git+https://github.com/huggingface/transformers.git@6b550462139655d488d4c663086a63e98713c6b9
+transformers @ git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf
```
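If the pinned fork needs to be installed manually rather than through the requirements file, the same line works as a pip spec, e.g. `pip install "transformers @ git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf"` (shown only as an illustration of the pin; the PR itself only changes the requirements entry).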
A new test file (128 lines) is added; it checks that the Qwen2.5-VL multimodal processor resizes image inputs to multiples of 112:

```python
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
# from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageProcessorForceAlignment

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2.5-VL-3B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("resize_shape"), [
        ((112, 112)),
        ((114, 114)),
        ((256, 221)),
        ((1024, 1080)),
        ((784, 1120)),
    ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_force_alignment_resize(
    image_assets: _ImageAssets,
    model_id: str,
    resize_shape: tuple[int, int],
    num_imgs: int,
):
    """Ensure images are resized by factor 112."""

    w, h = resize_shape
    factor = 112
    h_bar = round(h / factor) * factor
    w_bar = round(w / factor) * factor
    expected_pixels_shape_zero = (w_bar // 14) * (h_bar // 14)
    expected_pixels_shape_one = 1176
    expected_toks_per_img = expected_pixels_shape_zero // 4
    mm_processor_kwargs = {}
    #mm_processor_kwargs = {"force_alignment": True}

    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        trust_remote_code=ctx.model_config.trust_remote_code,
    )
    processor = MULTIMODAL_REGISTRY.create_processor(
        ctx.model_config,
        tokenizer=tokenizer,
    )

    # Build the image str / prompt based on the number of images we pass
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
    mm_data = {"image": [image_assets[0].pil_image.resize(resize_shape)] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)

    # Ensure we have the right number of placeholders per num_crops size
    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

    assert img_tok_count == expected_toks_per_img * num_imgs
    assert pixel_shape[0] == expected_pixels_shape_zero * num_imgs
    assert pixel_shape[1] == expected_pixels_shape_one
    assert pixel_shape[0] % 64 == 0


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2.5-VL-3B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("resize_shape"), [
        ((110, 112)),
        ((32, 32)),
    ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_force_alignment_resize_to_min_value(
    image_assets: _ImageAssets,
    model_id: str,
    resize_shape: tuple[int, int],
    num_imgs: int,
):
    """Ensure processor resizes small images to 112 x 112"""
    expected_pixels_shape_zero = (112 // 14) * (112 // 14)
    expected_pixels_shape_one = 1176
    expected_toks_per_img = expected_pixels_shape_zero // 4

    mm_processor_kwargs = {}

    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        trust_remote_code=ctx.model_config.trust_remote_code,
    )
    processor = MULTIMODAL_REGISTRY.create_processor(
        ctx.model_config,
        tokenizer=tokenizer,
    )

    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
    mm_data = {"image": [image_assets[0].pil_image.resize(resize_shape)] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)

    # Ensure we have the right number of placeholders per num_crops size
    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

    assert img_tok_count == expected_toks_per_img * num_imgs
    assert pixel_shape[0] == expected_pixels_shape_zero * num_imgs
    assert pixel_shape[1] == expected_pixels_shape_one
    assert pixel_shape[0] % 64 == 0
```

An inline review comment on the commented-out `#mm_processor_kwargs = {"force_alignment": True}` line asks: to be removed?
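As a sanity check of the expected values, take the `resize_shape = (256, 221)` case: both sides round to 224, so `pixel_values.shape[0]` is (224 // 14) * (224 // 14) = 256 per image, the expected image-token count is 256 // 4 = 64, and 256 is indeed a multiple of 64. Assuming the file lands alongside the fork's other multimodal processing tests (the exact path is not shown here), it should be runnable with something like `pytest -k force_alignment` from the repository root.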
Review comment: we should remove this line?