From 635819205ee783056d7ed10af036ea12a1984a7b Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Tue, 8 Apr 2025 21:57:27 +0000 Subject: [PATCH 01/38] Revert "Enabled and optimized GLM-4v-9b on Gaudi (#691)" This reverts commit c0e696bf7f5b5569598fdd354a30250f362ac4ef. --- vllm/model_executor/models/utils.py | 34 ----------------------------- 1 file changed, 34 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 9925fe16d39c..fff4be34ddbe 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -14,7 +14,6 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -444,14 +443,6 @@ def merge_multimodal_embeddings( Note: This updates ``inputs_embeds`` in place. """ - if current_platform.is_hpu(): - return _hpu_merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - placeholder_token_id, - ) - if isinstance(placeholder_token_id, list): placeholder_token_id = torch.tensor(placeholder_token_id, device=input_ids.device) @@ -650,28 +641,3 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] - - -def _hpu_merge_multimodal_embeddings( - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: NestedTensors, - placeholder_token_id: torch.tensor, -) -> torch.Tensor: - """ - Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the - positions in ``inputs_embeds`` corresponding to placeholder tokens in - ``input_ids``. - merge_multimodal_embeddings on HPU to avoid dynamicity. - Note: - This updates ``inputs_embeds`` in place. - """ - batch_size, seq_length, hidden_size = inputs_embeds.shape - inputs_embeds = inputs_embeds.reshape(-1, hidden_size) - multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size) - placeholder_token_id = torch.tensor(placeholder_token_id, - device=input_ids.device) - mask = torch.isin(input_ids.reshape(-1), placeholder_token_id) - inputs_embeds.index_put_((mask, ), multimodal_embeddings) - inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, hidden_size) - return inputs_embeds From a1097dd8e9566e08174dccdbd0460d6a37cca0ec Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Mon, 14 Apr 2025 18:21:34 +0000 Subject: [PATCH 02/38] fea(): Qwen2.5-vl upgrades. 
initial commit --- requirements-hpu-qwen2_5_vl.txt | 2 +- .../multimodal/processing/test_qwen2_5_vl.py | 128 +++ vllm/model_executor/models/qwen2_5_vl.py | 251 +++++- vllm/model_executor/models/utils.py | 19 +- vllm/worker/hpu_model_runner.py | 845 +++++++++++------- 5 files changed, 874 insertions(+), 371 deletions(-) create mode 100644 tests/models/multimodal/processing/test_qwen2_5_vl.py diff --git a/requirements-hpu-qwen2_5_vl.txt b/requirements-hpu-qwen2_5_vl.txt index 21bcfbfe0b11..0ca709c1c926 100644 --- a/requirements-hpu-qwen2_5_vl.txt +++ b/requirements-hpu-qwen2_5_vl.txt @@ -1 +1 @@ -transformers @ git+https://github.com/huggingface/transformers.git@6b550462139655d488d4c663086a63e98713c6b9 +transformers @ git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf diff --git a/tests/models/multimodal/processing/test_qwen2_5_vl.py b/tests/models/multimodal/processing/test_qwen2_5_vl.py new file mode 100644 index 000000000000..c4b85bfcec85 --- /dev/null +++ b/tests/models/multimodal/processing/test_qwen2_5_vl.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer +# from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageProcessorForceAlignment + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["Qwen/Qwen2.5-VL-3B-Instruct"]) +# yapf: disable +@pytest.mark.parametrize( + ("resize_shape"), [ + ((112, 112)), + ((114, 114)), + ((256, 221)), + ((1024, 1080)), + ((784, 1120)), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_force_alignment_resize( + image_assets: _ImageAssets, + model_id: str, + resize_shape: tuple[int, int], + num_imgs: int, +): + """Ensure images are resized by factor 112.""" + + w, h = resize_shape + factor = 112 + h_bar = round(h / factor) * factor + w_bar = round(w / factor) * factor + expected_pixels_shape_zero = (w_bar // 14) * (h_bar // 14) + expected_pixels_shape_one = 1176 + expected_toks_per_img = expected_pixels_shape_zero // 4 + mm_processor_kwargs = {} + #mm_processor_kwargs = {"force_alignment": True} + + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + trust_remote_code=ctx.model_config.trust_remote_code, + ) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) + + # Build the image str / prompt based on the number of images we pass + prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs + mm_data = {"image": [image_assets[0].pil_image.resize(resize_shape)] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + + hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + + assert img_tok_count == expected_toks_per_img * num_imgs + assert pixel_shape[0] == expected_pixels_shape_zero * num_imgs + assert pixel_shape[1] == expected_pixels_shape_one + assert pixel_shape[0] % 64 == 0 + 
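+
+# Shape arithmetic behind the assertions above and below (assuming the stock
+# Qwen2.5-VL vision settings: patch_size=14, spatial_merge_size=2 and
+# temporal_patch_size=2; these notes are illustrative, not part of the processor):
+#   - forcing H/W to multiples of 112 = 8 * 14 gives (112/14)**2 = 64 patches per
+#     112x112 tile, hence the pixel_shape[0] % 64 == 0 checks
+#   - each patch flattens to 3 * 2 * 14 * 14 = 1176 values, hence pixel_shape[1]
+#   - the 2x2 spatial merge folds 4 patches into one <|image_pad|> token, hence
+#     expected_toks_per_img = expected_pixels_shape_zero // 4
+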
+@pytest.mark.parametrize("model_id", ["Qwen/Qwen2.5-VL-3B-Instruct"]) +# yapf: disable +@pytest.mark.parametrize( + ("resize_shape"), [ + ((110, 112)), + ((32, 32)), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1]) +def test_processor_force_alignment_resize_to_min_value( + image_assets: _ImageAssets, + model_id: str, + resize_shape: tuple[int, int], + num_imgs: int, +): + """Ensure processor resizes small images to 112 x 112""" + expected_pixels_shape_zero = (112 // 14) * (112 // 14) + expected_pixels_shape_one = 1176 + expected_toks_per_img = expected_pixels_shape_zero // 4 + + mm_processor_kwargs = {} + + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + trust_remote_code=ctx.model_config.trust_remote_code, + ) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) + + prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs + mm_data = {"image": [image_assets[0].pil_image.resize(resize_shape)] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + + hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + + assert img_tok_count == expected_toks_per_img * num_imgs + assert pixel_shape[0] == expected_pixels_shape_zero * num_imgs + assert pixel_shape[1] == expected_pixels_shape_one + assert pixel_shape[0] % 64 == 0 diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 0b3f9014568b..ae43a5d77628 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -61,6 +61,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from habana_frameworks.torch.hpex.kernels import FusedSDPA +import os +import habana_frameworks.torch.core as htcore + from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, @@ -75,6 +79,45 @@ # === Vision Inputs === # +class AttentionLongSequence: + @staticmethod + def forward(q, k, v, mask, q_block_size): + """ + Support long sequence at prompt phase + """ + q_len = q.size(-2) + assert q_len % q_block_size == 0 + q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) + #q_padding = q_tiles * q_block_size - q_len + #q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) + #if mask is not None: + # mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) + attn_output = torch.zeros_like(q) + + for i in range(q_tiles): + s, e = i * q_block_size, (i + 1) * q_block_size + row_q = q[:, :, s:e, :] + row_mask = mask[:, :, s:e, :] + attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None) + #TODO: markstep every 10th layer, didn't experiment which one is optimal number. 
+ #10,50,100 shows simliar result, without this, we see the program hangs for multiple prompts(with larger images) + if i % 75 == 0: + htcore.mark_step() + return attn_output + +def create_block_diagonal_attention_mask_outerprod(indices): + maxsize = indices[-1] + range_to_max_for_each_img = torch.arange(maxsize, device=indices.device).unsqueeze(0).repeat(indices.shape[0]-1,1) + yy = range_to_max_for_each_img < indices[1:].unsqueeze(1) + zz = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) + xx = torch.logical_and(yy, zz) + # can reduce sum externally or as batchmatmul + res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) + #res = torch.einsum('bi,bj->ij', xx.float(), xx.float()) + return res.bool() + +def expand_to_max(indices, max_num_images): + return torch.nn.functional.pad(indices, (0, max_num_images-indices.shape[0]), value=indices[-1]) class Qwen2_5_VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -216,6 +259,7 @@ def __init__( self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) + self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, @@ -260,7 +304,7 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def forward( self, x: torch.Tensor, - cu_seqlens: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] @@ -304,26 +348,43 @@ def forward( b=batch_size) elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - if is_hpu: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0) + if cu_seqlens is None: + outputs = [] + cu_seqlens = list(range(0, x.shape[0]+1, 64)) # assuming x%64=0 (image is 112 aligned in both h/w dims) + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + if is_hpu: + output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0) + else: + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + else: + fullatt_block_attn_mask = create_block_diagonal_attention_mask_outerprod(cu_seqlens) + q1, k1, v1 = (rearrange(x, "b s h d -> b h s d")for x in [q, k, v]) + + + (batch_size, n_heads, seq_len_N_t, head_dim_qk) = q1.shape + (batch_size, n_heads, seq_len_N_s, head_dim_qk) = k1.shape + mask_shape = (batch_size, 1, seq_len_N_t, seq_len_N_s) + attn_mask = fullatt_block_attn_mask.reshape(batch_size, 1, seq_len_N_t, seq_len_N_s, -1)[:, :, :, :, 0] + assert attn_mask.shape == mask_shape + + if q1.shape[2] <= 6400: # this crossover point should be measured + fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0) # Bx1xNxN else: - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = 
torch.cat(outputs, dim=1) + fused_out = AttentionLongSequence.forward(q1, k1, v1, attn_mask, 64) + context_layer = rearrange(fused_out, "b h s d -> b s h d ") elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +431,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp") - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + def forward(self, x: torch.Tensor, #cu_seqlens: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], rotary_pos_emb: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, @@ -553,7 +615,9 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() + ).permute( + 0, 2, 1, + 3).flatten() wpos_ids = wpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, @@ -603,11 +667,7 @@ def get_window_index(self, grid_thw): window_index = torch.cat(window_index, dim=0) return window_index, cu_window_seqlens - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: + def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor): # patchify hidden_states = x.to(device=self.device, dtype=self.dtype) hidden_states = self.patch_embed(hidden_states) @@ -650,26 +710,40 @@ def remove_duplicates_cpu(a): seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - # compute cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32) + return hidden_states, rotary_pos_emb, cu_seqlens, cu_window_seqlens, window_index + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == cu_seqlens[-1] == rotary_pos_emb.shape[0] cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) - # transformers + hidden_states = x.unsqueeze(1) hidden_states = hidden_states.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: + #fullatt_block_attn_mask = None cu_seqlens_now = cu_seqlens - else: + cu_seqlens_now = None cu_seqlens_now = cu_window_seqlens hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) # adapter + + return hidden_states + + def post_attn(self, hidden_states: torch.Tensor, + window_index: torch.Tensor): hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] return hidden_states @@ -828,6 +902,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + envvar = os.environ.get('FIXED_MULTIMODAL_BUCKETS', "") + if envvar == "": + self.FIXED_MULTIMODAL_BUCKETS = [1600, 3200, 4800, 6400] # add 768 a small bucket maybe? 
+ else: + self.FIXED_MULTIMODAL_BUCKETS = [int(i) for i in envvar.split(',')] + assert all([k%64 == 0 for k in self.FIXED_MULTIMODAL_BUCKETS]), f"FIXED_MULTIMODAL_BUCKETS should all be multiples of 64, but was {self.FIXED_MULTIMODAL_BUCKETS}" @cached_property def sampler(self): @@ -903,6 +983,7 @@ def _parse_and_validate_video_input( video_grid_thw = kwargs.pop("video_grid_thw", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + if pixel_values_videos is None and video_embeds is None: return None @@ -932,6 +1013,42 @@ def _parse_and_validate_video_input( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) + def _get_multimodal_bucket(self, curr_num_image_patches): + for mm_bucket in self.FIXED_MULTIMODAL_BUCKETS: + if curr_num_image_patches <= mm_bucket: + return mm_bucket + self.FIXED_MULTIMODAL_BUCKETS += [curr_num_image_patches] # a shape larger than any that was compiled before. its gonna be compiled now, so save it for the future + return curr_num_image_patches + + def pad_multimodal_data(self, pixel_values, image_grid_thw): + assert pixel_values.shape[ + 0] % 64 == 0, '[testing version] needs 64 aligned resolution' + + desired_number_of_pixels = self._get_multimodal_bucket(pixel_values.shape[0]) + padding_len = desired_number_of_pixels - pixel_values.shape[0] + if padding_len <= 0: + #breakpoint() + return pixel_values, image_grid_thw + + logger.info( + f"[MM_BUCKETING] Padding current number pixel {pixel_values.shape[0]} to {desired_number_of_pixels}" + ) + # needs to make sure padding_len is even + assert padding_len % 64 == 0, '[testing version] padding needs to be multiple of 64' + + constant_value = -100 + pixel_values = torch.cat([ + pixel_values, + torch.ones((padding_len, pixel_values.shape[1]), device=pixel_values.device) * constant_value + ]) + + image_grid_thw = torch.cat( + [image_grid_thw, + torch.tensor([[1, 8, padding_len // 8]], device=image_grid_thw.device)]) + + assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels + return pixel_values, image_grid_thw + def _process_image_input( self, @@ -943,7 +1060,67 @@ def _process_image_input( if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) + + if True: + ''' + go thru grid_thw + say grid_thw is 1,16,16 and 1,128,128 + say u have 2 buckets: 512 and 16384 + + slice pixel_values at 16*16 = 256 (1st img) and attach a new "image" to it, to pad it up to 512 + attach a new + ''' + + offset = 0 + # right now we do 1 img at a time, but if we have multiple small images we could pack them in together + # Like say if I have image = 224, 224, 6400, and my buckets are: 1024, 6400 + # instead of padding 224->1024 and 224->1024, we can pack both 224 into 1 and send it to 1024 + results = [] + # During warmup: self.model.visual_warmup_times isnt set, so we can do it 1-by-1 + # after warmup we need to check "visual_warmup_times" and we can batch based on that + # Note that sometimes we may recompile, in which case "_get_multimodal_bucket" will return a larger number + # however we will not have time for that larger size in "visual_warmup_times" + # so after that our policy will be: + # if size within original buckets attempt coalescing within original buckets + # if size is larger, only then use a already precompiled non-original bucket + for img_idx in range(grid_thw.shape[0]): + img_shape = grid_thw[img_idx, :].unsqueeze(0) + curr_img_size = 
img_shape.prod() + + pixel_values_curr_img = pixel_values[offset : offset + curr_img_size, :] + #breakpoint() + offset += curr_img_size + pixel_values_curr_img_padded, img_shape_padded = self.pad_multimodal_data(pixel_values_curr_img, img_shape) + + pixel_values_curr_img_padded, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( + pixel_values_curr_img_padded, img_shape_padded) + + assert pixel_values.shape[0] % 64 == 0, f"We need image h/w to be aligned to 112 for now. Which will make pixel_values be a multiple of (112/14)*(112/14)=64 (14 is patch size for ViT). Got pixel_values shape {pixel_values.shape[0]}" + + expanded_cu_seqlens = expand_to_max(cu_seqlens, 3) # either a single image, or a single image and its accompanying pad image, so only max expansion to 3 + htcore.mark_step() + hidden_states = self.visual(pixel_values_curr_img_padded, + rotary_pos_emb=rot_pos_emb, + cu_seqlens=expanded_cu_seqlens,) + htcore.mark_step() + image_embeds = self.visual.post_attn(hidden_states, window_index) + results += [image_embeds[:img_shape_padded[0].prod()//4, :]] # slice image_embeds to remove the padded parts. instead of hardcoding 4, maybe use config spatial merge etc + results_cat = torch.concat(results) + image_embeds = results_cat + else: + pixel_values, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( + pixel_values, grid_thw) + assert pixel_values.shape[0] % 64 == 0, f"We need image h/w to be aligned to 112 for now. Which will make pixel_values be a multiple of (112/14)*(112/14)=64 (14 is patch size for ViT). Got pixel_values shape {pixel_values.shape[0]}" + #print('.......', cu_seqlens, expand_to_max(cu_seqlens, 10)) + expanded_cu_seqlens = expand_to_max(cu_seqlens, 10) + htcore.mark_step() # padding in expand_to_max is dynamic + #breakpoint() + hidden_states = self.visual(pixel_values, + rotary_pos_emb=rot_pos_emb, + cu_seqlens=expanded_cu_seqlens,) + #cu_window_seqlens=cu_window_seqlens) + htcore.mark_step() + image_embeds = self.visual.post_attn(hidden_states, window_index) image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. @@ -964,6 +1141,18 @@ def _process_video_input( else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) + + #Moved dynamic calculation to pre_attn, and post_attn and keep the visual() block to be static to include only VisionTransformer and VisionMerger. + pixel_values_videos, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( + pixel_values_videos, grid_thw) + expanded_cu_seqlens = expand_to_max(cu_seqlens, 10) + htcore.mark_step() # padding in expand_to_max is dynamic + hidden_states = self.visual(pixel_values_videos, + rotary_pos_emb=rot_pos_emb, + cu_seqlens=expanded_cu_seqlens,) + #cu_window_seqlens=cu_window_seqlens) + video_embeds = self.visual.post_attn(hidden_states, window_index) + #video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
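A minimal, standalone sketch of the fixed-bucket padding that _get_multimodal_bucket() and pad_multimodal_data() above implement, assuming the default FIXED_MULTIMODAL_BUCKETS values and the dummy [1, 8, padding_len // 8] grid row used in the hunk; the helper names and the example image size here are illustrative only, not part of the patch:

    import torch

    FIXED_MULTIMODAL_BUCKETS = [1600, 3200, 4800, 6400]  # patch counts; all multiples of 64

    def pick_bucket(num_patches: int) -> int:
        # smallest precompiled bucket that fits; anything larger becomes its own bucket
        for bucket in FIXED_MULTIMODAL_BUCKETS:
            if num_patches <= bucket:
                return bucket
        return num_patches

    def pad_to_bucket(pixel_values: torch.Tensor, grid_thw: torch.Tensor):
        # pixel_values: (num_patches, 1176); grid_thw: (num_images, 3) rows of [t, h, w] in patches
        target = pick_bucket(pixel_values.shape[0])
        pad = target - pixel_values.shape[0]
        if pad <= 0:
            return pixel_values, grid_thw
        filler = torch.full((pad, pixel_values.shape[1]), -100.0, dtype=pixel_values.dtype)
        dummy = torch.tensor([[1, 8, pad // 8]], dtype=grid_thw.dtype)  # fake image that owns the padding
        return torch.cat([pixel_values, filler]), torch.cat([grid_thw, dummy])

    # e.g. a 672x784 image -> 48 * 56 = 2688 patches, padded with 512 rows up to the 3200 bucket
    pv, grid = pad_to_bucket(torch.randn(2688, 1176), torch.tensor([[1, 48, 56]]))
    assert pv.shape[0] == 3200 and grid.prod(-1).sum() == 3200

Because force-aligned images always contribute multiples of (112/14)**2 = 64 patches, the padding itself stays 64-aligned, so each image reuses one of a handful of precompiled HPU graph shapes instead of triggering a recompile per resolution.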
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 3e969415a842..e55812d1a2a2 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -366,13 +366,15 @@ def _merge_multimodal_embeddings( assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) - if flattened.shape[0] != num_expected_tokens: - expr = _embedding_count_expression(multimodal_embeddings) - raise ValueError( - f"Attempted to assign {expr} = {flattened.shape[0]} " - f"multimodal tokens to {num_expected_tokens} placeholders") + # if flattened.shape[0] != num_expected_tokens: + # expr = _embedding_count_expression(multimodal_embeddings) + # raise ValueError( + # f"Attempted to assign {expr} = {flattened.shape[0]} " + # f"multimodal tokens to {num_expected_tokens} placeholders") + + # flattened could have dummy data from the padding after num_expected_tokens + inputs_embeds[is_multimodal] = flattened[:num_expected_tokens, :] - inputs_embeds[is_multimodal] = flattened return inputs_embeds @@ -597,15 +599,12 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, - context_size: int, dtype: torch.dtype, device: torch.device, ) -> IntermediateTensors: return IntermediateTensors({ key: - torch.zeros((batch_size, context_size, hidden_size), - dtype=dtype, - device=device) + torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) for key in keys }) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 80bda7407f49..57550c9199e5 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -22,9 +22,10 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing.common import get_bucketing_context +from vllm_hpu_extension.bucketing import HPUBucketingContext from vllm_hpu_extension.flags import enabled_flags from vllm_hpu_extension.ops import LoraMask as LoraMask +from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) @@ -32,7 +33,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.hpu_attn import HPUAttentionImpl from vllm.config import DeviceConfig, VllmConfig -from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import get_world_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -82,8 +83,6 @@ VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true' -VLLM_MERGED_PREFILL = os.environ.get('VLLM_MERGED_PREFILL', - 'false').lower() == 'true' DUMMY_TOKEN_ID = -1 @@ -157,16 +156,6 @@ def flatten(in_list): return list(itertools.chain(*in_list)) -def make_cpu_tensor(data, max_len, pad, dtype, flat) -> torch.Tensor: - if flat: - data = [flatten(data)] - return make_tensor_with_pad(data, - max_len=max_len, - pad=pad, - dtype=dtype, - device='cpu') - - def get_target_layer_suffix_list(model_type) -> list[str]: # This sets the suffix for the hidden layer name, which is controlled by # VLLM_CONFIG_HIDDEN_LAYERS. 
The default suffix is "DecoderLayer," which is @@ -240,10 +229,9 @@ def find_rope_layer(parent, path): return path_to_rope -class HpuModelAdapter(torch.nn.Module): +class HpuModelAdapter: def __init__(self, model, vllm_config, layer_names): - super().__init__() self.model = model self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags() self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', @@ -252,11 +240,60 @@ def __init__(self, model, vllm_config, layer_names): self.block_size = vllm_config.cache_config.block_size self.dtype = vllm_config.model_config.dtype self.layer_names = layer_names + enforce_eager = vllm_config.model_config.enforce_eager self.is_pooler = hasattr(self.model, "_pooler") self.is_causal = True if self.is_pooler: self.set_causal_option(self.model) - self.use_merged_prefill = VLLM_MERGED_PREFILL + if not is_fake_hpu() and not htorch.utils.internal.is_lazy( + ) and not enforce_eager: + if os.getenv('VLLM_REGIONAL_COMPILATION', + 'true').lower() == 'true': + self.regional_compilation_layers_list = [ + RMSNorm, VocabParallelEmbedding + ] + self._regional_compilation(self.model) + else: + self.model = torch.compile(self.model, + backend='hpu_backend', + dynamic=False) + + model_config = getattr(self.model, "config", None) + self.model_is_mrope = uses_mrope(model_config) + + # For qwen2.5-VL model, we wrap visual model with disable_tensor_cache + # off due to handling of grid_thw. For langauge model, we wrap it with + # disable_tensor_cache on to save memory. Here we can either wrap it with + # self.model or self.model.language_model.model. + self.split_graph = self.model_is_mrope and os.getenv( + 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] + + if htorch.utils.internal.is_lazy() and self.split_graph: + print("Split Graph to Visual and Language") + self.model.visual = htorch.hpu.wrap_in_hpu_graph( + self.model.visual, disable_tensor_cache=False) + self.model.language_model.model = htorch.hpu.wrap_in_hpu_graph( + self.model.language_model.model, disable_tensor_cache=True) + + def _regional_compilation(self, + module, + parent_module=None, + module_name=None): + if isinstance(module, torch.nn.ModuleList): + for children_name, children_module in module.named_children(): + self._compile_region(module, children_name, children_module) + elif any( + isinstance(module, layer) + for layer in self.regional_compilation_layers_list): + self._compile_region(parent_module, module_name, module) + else: + for children_name, children_module in module.named_children(): + self._regional_compilation(children_module, module, + children_name) + + def _compile_region(self, model, name, module): + module = torch.compile(module, backend='hpu_backend', dynamic=False) + setattr(model, name, module) def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): @@ -266,9 +303,6 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, or not attn_metadata.is_prompt): return attn_metadata - if attn_metadata.attn_bias is not None: - return attn_metadata - prefill_metadata = attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor @@ -345,10 +379,20 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): attn_bias=attn_bias) return metadata + def _set_block_scales(self, metadata, device): + block_mapping = metadata.block_mapping + ones = torch.ones((block_mapping.size(0), ), + device=device, + dtype=block_mapping.dtype) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, 
sums)) + metadata = metadata._replace(block_scales=block_scales) + return metadata + def _set_indices_and_offsets(self, metadata, block_size, is_prompt): slot_mapping = metadata.slot_mapping.flatten() indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - if is_prompt and not self.use_merged_prefill: + if is_prompt: indices = indices.unflatten(0, (-1, block_size))[:, 0] offsets = None else: @@ -366,6 +410,7 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device, else: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) + attn_metadata = self._set_block_scales(attn_metadata, device) attn_metadata = self._set_indices_and_offsets(attn_metadata, self.block_size, attn_metadata.is_prompt) @@ -413,16 +458,57 @@ def forward(self, *args, **kwargs): input_ids.device, self.dtype) if 'lora_mask' in kwargs: LoraMask.setLoraMask(kwargs.pop('lora_mask')) - model_config = getattr(self.model, "config", None) - model_is_mrope = uses_mrope(model_config) - if self.layer_names is not None and not model_is_mrope: + + if self.layer_names is not None and not self.model_is_mrope: self._prepare_cos_sin(kwargs['positions']) + if self.model_is_mrope: # and self.split_graph: + if self.split_graph: + # Carry bypass_hpu_graphs to visual model forward. + bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False) + self.model.visual.forward = functools.partial( + self.model.visual.forward, + bypass_hpu_graphs=bypass_hpu_graphs) + self.model.language_model.model.forward = functools.partial( + self.model.language_model.model.forward, + bypass_hpu_graphs=bypass_hpu_graphs) + #self.model.forward = functools.partial( + # self.model.forward, bypass_hpu_graphs=bypass_hpu_graphs) + + # For Qwen2.5-VL multimodal embedding, + # This embedding part should be always executed with PT_COMPILE_ONLY_MODE off + # at all time. We are turning it off here since it will be on during warmup run. + # Also, we are moving this code block to here from model.forward() since we don't want + # to wrap this with hpu_graph. This block has issue with disable_tensor_cache=true. 
+ compile_only_mode_context = functools.partial( + bc.env_setting, "PT_COMPILE_ONLY_MODE", False) + + with compile_only_mode_context(): + #calculate embedding for multimodal + #breakpoint() + image_input = self.model._parse_and_validate_image_input( + **kwargs) + video_input = self.model._parse_and_validate_video_input( + **kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + inputs_embeds = self.model.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + kwargs.update({ + "input_ids": input_ids, + "inputs_embeds": inputs_embeds + }) + with set_forward_context(kwargs['attn_metadata'], self.vllm_config, virtual_engine): + #breakpoint() hidden_states = self.model(*args, **kwargs) - if not get_pp_group().is_last_rank: - return hidden_states hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if selected_token_indices is not None: hidden_states = hidden_states.index_select( @@ -435,9 +521,6 @@ def compute_logits(self, *args, **kwargs): def sample(self, *args, **kwargs): return self.model.sample(*args, **kwargs) - def make_empty_intermediate_tensors(self, *args, **kwargs): - return self.model.make_empty_intermediate_tensors(*args, **kwargs) - def generate_proposals(self, *args, **kwargs): return self.model.generate_proposals(*args, **kwargs) @@ -652,10 +735,6 @@ def __init__( self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens self.block_size = self.cache_config.block_size - self.use_merged_prefill = VLLM_MERGED_PREFILL - assert not (self.scheduler_config.use_padding_aware_scheduling - and self.use_merged_prefill), \ - 'Merged prefill is not compatible with padding aware scheduling!' self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = self.cache_config.cache_dtype @@ -690,14 +769,13 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - HPUBucketingContext = get_bucketing_context() self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, - self.max_num_batched_tokens, - self.use_merged_prefill, - self.max_model_len) + self.max_num_batched_tokens) self.graphed_buckets: Set[Any] = set() + self.multimodal_buckets = [] #This should be use HPUBucketingContext + self.graphed_multimodal_buckets: Set[Any] = set() self._set_gc_threshold() if self.vllm_config.cache_config.enable_prefix_caching: @@ -847,12 +925,6 @@ def load_model(self) -> None: layer_names=path_to_rope) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) - with HabanaMemoryProfiler() as m_wrap: - self._maybe_compile(self.model, - vllm_config=self.vllm_config, - layer_names=path_to_rope) - msg = f"Compiling took {m_wrap.get_summary_string()}" - logger.info(msg) self.model_memory_usage = m.consumed_device_memory msg = f"Loading model weights took in total {m.get_summary_string()}" @@ -881,82 +953,32 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): return seq_group_metadata_list, real_batch_size, batch_size_padded def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): - return htorch.hpu.wrap_in_hpu_graph( - HpuModelAdapter(*args, **kwargs), - disable_tensor_cache=True, - ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( - *args, **kwargs) - - def _maybe_compile(self, *args, **kwargs): - if not is_fake_hpu() and not htorch.utils.internal.is_lazy( - ) and not 
self.vllm_config.model_config.enforce_eager: - fullgraph = os.getenv('VLLM_T_COMPILE_FULLGRAPH', - 'false').strip().lower() in ("1", "true") - if os.getenv('VLLM_REGIONAL_COMPILATION', - 'true').strip().lower() in ("1", "true"): - compiled_methods = [self.model._set_block_mapping] - for method in compiled_methods: - method = torch.compile(method, - backend='hpu_backend', - fullgraph=fullgraph, - dynamic=False) - self.regional_compilation_layers_list = [ - RMSNorm, VocabParallelEmbedding - ] - self._regional_compilation(self.model, fullgraph) - else: - self.model = torch.compile(self.model, - backend='hpu_backend', - fullgraph=fullgraph, - dynamic=False) - - def _regional_compilation(self, - module, - fullgraph, - parent_module=None, - module_name=None): - if isinstance(module, torch.nn.ModuleList): - for children_name, children_module in module.named_children(): - self._compile_region(module, fullgraph, children_name, - children_module) - elif any( - isinstance(module, layer) - for layer in self.regional_compilation_layers_list): - self._compile_region( - parent_module, - fullgraph, - module_name, - module, - ) + self.split_graph = self.model_is_mrope and os.getenv( + 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] + if htorch.utils.internal.is_lazy() and not self.split_graph: + return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter( + *args, **kwargs), + disable_tensor_cache=True) else: - for children_name, children_module in module.named_children(): - self._regional_compilation(children_module, fullgraph, module, - children_name) - - def _compile_region( - self, - model, - fullgraph, - name, - module, - ): - module = torch.compile(module, - backend='hpu_backend', - fullgraph=fullgraph, - dynamic=False) - setattr(model, name, module) + return HpuModelAdapter(*args, **kwargs) def get_model(self) -> torch.nn.Module: if isinstance(self.model, HpuModelAdapter): return self.model.model return self.model - def _use_graphs(self, batch_size, seq_len, is_prompt): + def _use_graphs(self, batch_size, seq_len, is_prompt, is_multimodal): if self.enforce_eager: return False if self.skip_warmup: return True - return (batch_size, seq_len, is_prompt) in self.graphed_buckets + if not is_multimodal or not self.graphed_multimodal_buckets: + return (batch_size, seq_len, is_prompt) in self.graphed_buckets + else: + #TODO:For now return TRUE for development + #This needs to be updated later with proper bucket detections. 
+ #return (batch_size, height, weight) in self.graphed_multimodal_buckets + return True def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens @@ -1021,34 +1043,6 @@ def _get_mrope_positions_and_delta(self, seq_data, mm_kwargs, context_len): assert mrope_positions is not None return mrope_positions, mrope_position_delta - def make_attn_bias(self, seq_lens, max_prompt_len, dtype): - seq_pos = [list(range(sl)) for sl in seq_lens] - seq_idx = [[i] * sl for i, sl in enumerate(seq_lens)] - seq_pos_t = make_cpu_tensor(seq_pos, - max_len=max_prompt_len, - pad=-1, - dtype=torch.long, - flat=self.use_merged_prefill) - seq_idx_t = make_cpu_tensor(seq_idx, - max_len=max_prompt_len, - pad=-1, - dtype=torch.long, - flat=self.use_merged_prefill) - q_seq_idx_t = seq_idx_t.unsqueeze(-1) - kv_seq_idx_t = seq_idx_t.unsqueeze(-2) - q_seq_pos_t = seq_pos_t.unsqueeze(-1) - kv_seq_pos_t = seq_pos_t.unsqueeze(-2) - seq_idx_t = q_seq_idx_t != kv_seq_idx_t - seq_pos_t = kv_seq_pos_t > q_seq_pos_t - attn_mask = seq_idx_t | seq_pos_t - attn_bias = torch.zeros_like(attn_mask, dtype=dtype) - attn_bias.masked_fill_(attn_mask, -math.inf) - return attn_bias.unsqueeze(1) - - def move_to_device(self, tensor): - return tensor if tensor is None else tensor.to(self.device, - non_blocking=True) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1102,10 +1096,6 @@ def _prepare_prompt( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size - if context_len == seq_len \ - and self.vllm_config.cache_config.enable_prefix_caching: - # Fully cached prompt - compute only last token - context_len = context_len - 1 prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: @@ -1145,6 +1135,9 @@ def _prepare_prompt( seq_group_metadata.mm_processor_kwargs, ) + # padding image patches (pixel_values, image_grid_thw) + #mm_kwargs = pad_multimodal_data(mm_kwargs) + # special processing for mrope position deltas. 
if self.model_is_mrope: mrope_positions, mrope_position_delta = \ @@ -1201,14 +1194,13 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) - if self.use_merged_prefill: - target_query_len = sum(query_lens) - else: - target_query_len = max(query_lens) + max_query_len = max(query_lens) real_num_seqs = len(query_lens) + assert max_query_len > 0 + max_prompt_len = max( - self.bucketing_ctx.get_padded_prompt_seq_len(target_query_len), + self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), self.block_size) lora_ids: List[int] = [] @@ -1248,40 +1240,34 @@ def _prepare_prompt( else: prefix_block_list_tensor = None - input_tokens_tensor = make_cpu_tensor(input_tokens, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - flat=self.use_merged_prefill) + input_tokens_tensor = make_tensor_with_pad(input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device='cpu') + if self.model_is_mrope: input_positions = \ make_mrope_positions_tensor_with_pad(input_positions=input_positions, - input_mrope_positions=input_mrope_positions, - max_prompt_len=max_prompt_len, - pad=0) + input_mrope_positions=input_mrope_positions, + max_prompt_len=max_prompt_len, + pad=0) else: - input_positions = make_cpu_tensor(input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - flat=self.use_merged_prefill) - - slot_mapping = make_cpu_tensor(slot_mapping, - max_len=max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - flat=self.use_merged_prefill) - - attn_bias = None - seq_lens_tensor = None - context_lens_tensor = None + input_positions = make_tensor_with_pad(input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device='cpu') - if self.use_merged_prefill: - attn_bias = self.make_attn_bias(seq_lens, max_prompt_len, - self.model_config.dtype) + slot_mapping = make_tensor_with_pad(slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device='cpu') seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu') + context_lens_tensor = torch.tensor(context_lens, dtype=torch.long, device='cpu') @@ -1295,15 +1281,18 @@ def _prepare_prompt( # Note: num_prefill_tokens is calculated using the length of # input_tokens after padding. 
num_prefill_tokens = input_tokens_tensor.numel() - - prefix_block_list_tensor = self.move_to_device( - prefix_block_list_tensor) - input_tokens_tensor = self.move_to_device(input_tokens_tensor) - input_positions = self.move_to_device(input_positions) - seq_lens_tensor = self.move_to_device(seq_lens_tensor) - slot_mapping = self.move_to_device(slot_mapping) - context_lens_tensor = self.move_to_device(context_lens_tensor) - attn_bias = self.move_to_device(attn_bias) + if prefix_block_list_tensor is not None: + prefix_block_list_tensor = prefix_block_list_tensor.to( + self.device, non_blocking=True) + input_tokens_tensor = input_tokens_tensor.to( # type: ignore + self.device, non_blocking=True) + input_positions = input_positions.to( # type: ignore + self.device, non_blocking=True) + slot_mapping = slot_mapping.to( # type: ignore + self.device, non_blocking=True) + seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) + context_lens_tensor = context_lens_tensor.to(self.device, + non_blocking=True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -1312,10 +1301,11 @@ def _prepare_prompt( block_usage=None, block_indices=None, block_offsets=None, + block_scales=None, block_groups=None, - attn_bias=attn_bias, + attn_bias=None, seq_lens=seq_lens, - seq_lens_tensor=self.move_to_device(seq_lens_tensor), + seq_lens_tensor=seq_lens_tensor, context_lens_tensor=context_lens_tensor, num_prefills=real_num_seqs, num_prefill_tokens=num_prefill_tokens, @@ -1595,6 +1585,7 @@ def _prepare_decode( block_usage=block_usage, block_indices=None, block_offsets=None, + block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1623,7 +1614,6 @@ def _prepare_decode( def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None ) -> Tuple[TModelInputForHPU, SamplingMetadata]: if len(seq_group_metadata_list) == 0: return self._model_input_cls(), None @@ -1681,14 +1671,9 @@ def prepare_input_tensors( ) = self._prepare_decode(decode_reqs) if not self.is_pooler: - generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens, - self.device, - self.pin_memory, - generators=generators) + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -1712,32 +1697,29 @@ def prepare_input_tensors( lora_requests = decode_lora_requests lora_ids = decode_lora_ids - if self.is_pooler: - sampling_metadata = None - elif not self.use_merged_prefill: - # FIXME: We need to adjust selected_token_indices to accommodate - # for padding - max_len = input_tokens.size(1) - paddings = [max_len - q for q in query_lens] - paddings = [0] + paddings[:-1] - paddings = list(itertools.accumulate(paddings)) - paddings_prompt_logprobs = [] + # FIXME: We need to adjust selected_token_indices to accommodate + # for padding + max_len = input_tokens.size(1) + paddings = [max_len - q for q in query_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] - if not self.is_pooler: - for i, seq_group_metadata in enumerate( - seq_group_metadata_list): - if seq_group_metadata.sampling_params \ - and seq_group_metadata.sampling_params.prompt_logprobs \ - is not None and seq_group_metadata.is_prompt: - paddings_prompt_logprobs += ([paddings[i]] * - 
seq_lens[i]) - - paddings = torch.tensor( - paddings_prompt_logprobs - if paddings_prompt_logprobs else paddings, - dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device) - sampling_metadata.selected_token_indices.add_(paddings) + if not self.is_pooler: + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params \ + and seq_group_metadata.sampling_params.prompt_logprobs \ + is not None and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) + + paddings = torch.tensor( + paddings_prompt_logprobs + if paddings_prompt_logprobs else paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) + else: + sampling_metadata = None if self.lora_config: lora_mapping = LoRAMapping( @@ -1807,76 +1789,6 @@ def prepare_input_tensors( lora_ids=lora_ids), \ sampling_metadata - def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], - is_prompt: bool): - ''' - This is a helper function to create the mask for lora computations. - Lora Mask is needed to ensure we match the correct lora weights for the - for the request. - For Prompt phase we have - lora_mask with shape (batch_size * seq_len, max_loras * max_rank) - lora_logits_mask with shape (batch_size, max_loras * max_rank) - For Decode phase we have both - lora_mask and lora_logits_mask with shape - (batch_size, max_loras * max_rank) - ''' - lora_mask: torch.Tensor = None - lora_logits_mask: torch.Tensor = None - lora_index = 0 - - if self.lora_config: - if is_prompt: - lora_mask = torch.zeros( - input_tokens.shape[0] * input_tokens.shape[1], - (self.lora_config.max_loras) *\ - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - lora_logits_mask = torch.zeros( - input_tokens.shape[0], (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - ones = torch.ones(input_tokens.shape[1], - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - logit_ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_row = i * input_tokens.shape[1] - end_row = start_row + input_tokens.shape[1] - start_col = lora_index * self.lora_config.max_lora_rank - end_col = start_col + self.lora_config.max_lora_rank - lora_mask[start_row:end_row, start_col:end_col] = ones - lora_logits_mask[i, start_col:end_col] = logit_ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_logits_mask.to('hpu') - else: - lora_mask = torch.zeros(input_tokens.shape[0], - (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_pos = lora_index * self.lora_config.max_lora_rank - end_pos = start_pos + self.lora_config.max_lora_rank - lora_mask[i, start_pos:end_pos] = ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_mask - - return lora_mask, lora_logits_mask - def _seq_len(self, attn_metadata): if 
attn_metadata.num_prefills != 0: return attn_metadata.slot_mapping.size(1) @@ -1915,10 +1827,63 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'is_prompt', 'block_indices', 'block_offsets', + 'block_scales', 'block_groups', ]) return attention_metadata + def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, + lora_request, temperature, + height, width): + + from vllm.multimodal.utils import cached_get_tokenizer + + if self.is_pooler: + sampling_params = None + else: + sampling_params = SamplingParams(temperature=temperature) + + assert self.mm_registry.has_processor(self.model_config) + tokenizer = cached_get_tokenizer( + self.model_config.tokenizer, + trust_remote_code=self.model_config.trust_remote_code, + ) + processor = self.mm_registry.create_processor(self.model_config, + tokenizer) + mm_counts = self.mm_registry.get_mm_limits_per_prompt( + self.model_config) + #mm_counts = {"image":1} + print("mm_counts:", mm_counts) + factory = processor.dummy_inputs + processor_inputs = factory.get_dummy_processor_inputs( + seq_len=seq_len, + mm_counts=mm_counts, + image_width=width, + image_height=height) + + mm_inputs = processor.apply( + prompt=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + + prompt_token_ids = mm_inputs["prompt_token_ids"] + placeholders_by_modality = mm_inputs["mm_placeholders"] + + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + seq_data = SequenceData.from_seqs(prompt_token_ids) + + return SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=lora_request[group_id] if lora_request else None, + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=placeholders_by_modality, + ) + def create_dummy_seq_group_metadata(self, group_id, seq_len, @@ -1931,6 +1896,7 @@ def create_dummy_seq_group_metadata(self, sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) seq_len = max(seq_len, 1) + if is_prompt: input_len = seq_len output_len = 0 @@ -1956,12 +1922,22 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, - [kv_caches] * self.parallel_config.pipeline_parallel_size) - _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - max_batch_size = min(self.max_num_seqs, - self.max_num_batched_tokens // max_seq_len) - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, - False, True) + [kv_caches]) + # FIXME Going to set this to on big batch indepedent of bucketing_ctx + # _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + # max_batch_size = min(self.max_num_seqs, + # self.max_num_batched_tokens // max_seq_len) + max_batch_size = 1 + max_seq_len = self.max_num_batched_tokens + self.warmup_scenario( + batch_size=max_batch_size, + seq_len=max_seq_len, + is_prompt=True, + kv_caches=kv_caches, + is_pt_profiler_run=False, + is_lora_profile_run=True, + multimodal_seqs_group_metada=True, + ) return def warmup_scenario(self, @@ -1971,8 +1947,13 @@ def warmup_scenario(self, kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False, - temperature=0) -> None: - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + multimodal_seqs_group_metada=False, + temperature=0, + height=None, + width=None, + return_time=False) -> None: + 
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, + multimodal_seqs_group_metada) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" @@ -1982,6 +1963,7 @@ def warmup_scenario(self, # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] dummy_lora_requests_per_seq: List[LoRARequest] = [] if self.lora_config and is_lora_profile_run: @@ -2003,7 +1985,21 @@ def warmup_scenario(self, ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_pt_profiler_run else 1 - if is_prompt: + if return_time: + times += 1 + + if multimodal_seqs_group_metada: + seqs = [ + self.create_dummy_multi_modal_seq_group_metadata( + group_id=i, + seq_len=seq_len, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None, + temperature=temperature, + height=height, + width=width) for i in range(batch_size) + ] + elif is_prompt: seqs = [ self.create_dummy_seq_group_metadata( i, @@ -2033,21 +2029,12 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) + if return_time: + tstart = time.time() is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = \ - self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - context_size=seq_len if is_prompt else 1, - dtype=self.model_config.dtype, - device=self.device) - self.execute_model(inputs, - kv_caches, - intermediate_tensors=intermediate_tensors, - warmup_mode=True) + self.execute_model(inputs, kv_caches, warmup_mode=True) else: # decode with multi-step inputs = dataclasses.replace(inputs, is_first_multi_step=True, @@ -2060,18 +2047,22 @@ def warmup_scenario(self, inputs = dataclasses.replace(inputs, is_first_multi_step=False, is_last_step=True) + # TODO: why 2 execute_model? self.execute_model(inputs, kv_caches, warmup_mode=True, num_steps=2, seqs=seqs) torch.hpu.synchronize() + if return_time: + t_total = time.time() - tstart if profiler: profiler.step() if profiler: profiler.stop() self.profiler.end() gc.collect() + return t_total if return_time else None def remove_all_loras(self): if not self.lora_manager: @@ -2116,11 +2107,67 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): f"free_mem:{free_mem}") logger.info(msg) + def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, + height, width): + free_mem = format_bytes( + HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if "Prompt" in phase: + dim = "seq_len" + msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len}", f"hw:({height},{width})", + f"free_mem:{free_mem}") + logger.info(msg) + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + # TODO: The plan here is loop over a couple of image + # resolutions and see if that helps during the warmup + # somehow indepedent of the batch_size, seq_len + # might need to mark.step() somewhere to split the + # HPU graph for video and language model + + # Warmup Multimodal with fixed seq_len + if not hasattr(self, 'visual_warmup_times'): + self.visual_warmup_times = {} + for i, (h, w) in enumerate(self.multimodal_buckets): + max_batch_size = 1 #TODO: For now we hardcoded batch 1. 
+ max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX + self.log_warmup_multimodal('Image', i, max_seq_len, max_batch_size, + max_seq_len, h, w) + assert h%112 == 0 and w % 112 == 0, "Expected to be 112 aligned for now" + t = self.warmup_scenario(batch_size=max_batch_size, + seq_len=max_seq_len, + is_prompt=True, + kv_caches=kv_caches, + is_pt_profiler_run=False, + is_lora_profile_run=True, + multimodal_seqs_group_metada=True, + height=h, + width=w, + return_time=True) + #if ((h*w) / (14*14)) in self.visual_warmup_times: + #breakpoint() + #print() + self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('nograph', t)] # TODO hardcoded "14" remove. "14" is a model specific number, maybe (h,w) or h*w is a better key? + + #Warmup without multimodal for text-prompt only + #TODO: We might need to warmup with smaller multimodal to generate + #3D position tensor for multimodal model. for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + #self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + if is_prompt: + self.warmup_scenario(batch_size, + seq_len, + is_prompt, + kv_caches, + multimodal_seqs_group_metada=True, + height=112, + width=112) # everythign must be 112 aligned (for now) + else: + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, @@ -2159,12 +2206,16 @@ def warmup_graphs(self, self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, - seq_len, - is_prompt, - kv_caches, - temperature=1.0 if batch_size - not in warmed_random_sampler_bs else 0) + self.warmup_scenario( + batch_size, + seq_len, + is_prompt, + kv_caches, + temperature=1.0 + if batch_size not in warmed_random_sampler_bs else 0, + multimodal_seqs_group_metada=True if is_prompt else False, + height=112, + width=112) # everythign must be 112 aligned (for now) warmed_random_sampler_bs.add(batch_size) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2172,6 +2223,34 @@ def warmup_graphs(self, total_mem += used_mem total_batch_seq += batch_seq + # TODO: Multimodal HPU graph warmup need to be also check Memory, + # and drop some buckets if memory is not sufficient. + print("WARMUP MULTIMODAL IMAGE GRAPH") + for idx, (h, w) in enumerate(self.multimodal_buckets): + graphed_multimodal_buckets = (1, h, w) + if graphed_multimodal_buckets in self.graphed_multimodal_buckets: + continue + self.graphed_multimodal_buckets.add(graphed_multimodal_buckets) + + for i, (b, h, w) in enumerate(self.graphed_multimodal_buckets): + max_batch_size = 1 #TODO: For now we hardcoded batch 1. 
+ max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX (1680x1680 error on HPU GRAPH) + self.log_warmup_multimodal('Graph/Image', i, max_seq_len, + max_batch_size, max_seq_len, h, w) + t = self.warmup_scenario( + batch_size=max_batch_size, + seq_len=max_seq_len, + is_prompt=True, + kv_caches=kv_caches, + #is_pt_profiler_run=False, + #is_lora_profile_run=True, + multimodal_seqs_group_metada=True, + height=h, + width=w, + return_time=True) + self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('graph', t)] + + return total_mem, total_batch_seq, captured_all def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @@ -2189,10 +2268,6 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: - if not self.is_pooler: - max_blocks = kv_caches[0][0].size(0) - self.bucketing_ctx.generate_decode_buckets(max_blocks) - if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' @@ -2202,6 +2277,35 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, True) raise AssertionError("Finished profiling") + if not self.is_pooler: + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_prompt_buckets() + + if supports_multimodal(self.model.model): + if True: + FIXED_MULTIMODAL_BUCKETS = self.model.model.FIXED_MULTIMODAL_BUCKETS + # [1600, 3200, 4800, 6400, 9600] + # 1600 means an image with 1600*14*14 ppixels, ie 560x560 + self.multimodal_buckets = [[112, total_size * 14 * 14 / 112] for total_size in FIXED_MULTIMODAL_BUCKETS] + # TODO This is qwen2.5vl/model specific code here. This should come from model file? + else: + #TODO: + # Multimodal buckets are based on H,W , it should be changed to be aligned with multimodal paddings. + # Also need to move to HPUBucketingContext. + #Multimodal bucket : [[560, 560], [560, 1120], [560, 1680], [1120, 560], [1120, 1120], [1120, 1680], [1680, 560], [1680, 1120], [1680, 1680]] + VLLM_MULTIMODAL_BUCKET = 560 #Pick number divisible by 28(patchsize*mergesize), this can be env. + max_seq_len = 1120 #2048 #self.max_num_batched_tokens + bucket = VLLM_MULTIMODAL_BUCKET + # TODO this number and self.FIXED_MULTIMODAL_BUCKETS should be in sync + self.multimodal_buckets = [ + [h, w] for h in range(bucket, max_seq_len + 1, bucket) + for w in range(bucket, max_seq_len + 1, bucket) + ] + breakpoint() + print("Multimodal bucket :", self.multimodal_buckets) + + if not self.is_pooler: + self.bucketing_ctx.generate_decode_buckets(max_blocks) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION', 'true').lower() == 'true' else 1 @@ -2235,6 +2339,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' 'Warmup time will be negatively impacted. 
' 'Please update Gaudi Software Suite.') + with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, @@ -2242,7 +2347,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if not self.is_pooler: self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, kv_caches) - + #breakpoint() # self.visual_warmup_times is populated at this point (but without hpu graphs) if not self.enforce_eager and htorch.utils.internal.is_lazy(): if not self.is_pooler: assert self.mem_margin is not None, \ @@ -2348,6 +2453,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: f"Warmup finished in {elapsed_time:.0f} secs, " f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) + #breakpoint() # self.visual_warmup_times is populated at this point (with and without hpu graphs) + # Another way to do this is for qwen2.5vl model (or any multimodal model) to track if a new shape is incoming, and then enable a timer. + # then this "time collection" logic is hidden in model file itself, and model_runner isnt tainted with it + # Also inside the model file, we may get a better estimate of the time. right now the time is a proxy as it also contains "text time" (though all text inp is of same len (2048)?) + # bt we'd need to markstep/sync=True if we are collecting times inside + if hasattr(self, 'visual_warmup_times'): + summary = {k: min([t for _, t in self.visual_warmup_times[k]]) for k in self.visual_warmup_times} + self.visual_warmup_times = summary + self.model.visual_warmup_times = self.visual_warmup_times # a strange way to pass this in, but gonna charge ahead for now with this self.profiler.end() def finish_measurements(self): @@ -2506,7 +2620,7 @@ def prepare_model_input( self.profiler_counter_helper.capture_seq_group_metadata_stats( seq_group_metadata_list=seq_group_metadata_list) model_input, sampling_metadata = self.prepare_input_tensors( - seq_group_metadata_list, finished_requests_ids) + seq_group_metadata_list) assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt @@ -2515,6 +2629,76 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], + is_prompt: bool): + ''' + This is a helper function to create the mask for lora computations. + Lora Mask is needed to ensure we match the correct lora weights for the + for the request. 
+ For Prompt phase we have + lora_mask with shape (batch_size * seq_len, max_loras * max_rank) + lora_logits_mask with shape (batch_size, max_loras * max_rank) + For Decode phase we have both + lora_mask and lora_logits_mask with shape + (batch_size, max_loras * max_rank) + ''' + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + lora_index = 0 + + if self.lora_config: + if is_prompt: + lora_mask = torch.zeros( + input_tokens.shape[0] * input_tokens.shape[1], + (self.lora_config.max_loras) *\ + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros( + input_tokens.shape[0], (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(input_tokens.shape[1], + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_row = i * input_tokens.shape[1] + end_row = start_row + input_tokens.shape[1] + start_col = lora_index * self.lora_config.max_lora_rank + end_col = start_col + self.lora_config.max_lora_rank + lora_mask[start_row:end_row, start_col:end_col] = ones + lora_logits_mask[i, start_col:end_col] = logit_ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + else: + lora_mask = torch.zeros(input_tokens.shape[0], + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_pos = lora_index * self.lora_config.max_lora_rank + end_pos = start_pos + self.lora_config.max_lora_rank + lora_mask[i, start_pos:end_pos] = ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_mask + + return lora_mask, lora_logits_mask + def _get_seq_ids(self, model_input): return ([ sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups @@ -2544,9 +2728,6 @@ def execute_model( use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode assert not (use_delayed_sampling and num_steps != 1), \ 'Delayed sampling is not compatible with MSS!' - assert not (use_delayed_sampling and - self.parallel_config.pipeline_parallel_size != 1), \ - 'Delayed sampling is not compatible with Pipeline Parallelism!' 
assert model_input.input_tokens is not None if use_delayed_sampling and not model_input.is_prompt and \ self.is_driver_worker: @@ -2617,7 +2798,8 @@ def execute_model( assert is_prompt is not None batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, + self.model_is_mrope) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None @@ -2627,6 +2809,7 @@ def execute_model( lora_mask, lora_logits_mask = self.create_lora_mask( input_tokens, model_input.lora_ids, attn_metadata.is_prompt) + execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -2694,6 +2877,7 @@ def try_revert_dummy_output_tokens(): with self.profiler.record_event('internal', model_event_name, args=profiler_args): + hidden_states = self.model.forward( **execute_model_kwargs, selected_token_indices=sampling_metadata. @@ -2703,8 +2887,6 @@ def try_revert_dummy_output_tokens(): LoraMask.setLoraMask( lora_logits_mask.index_select( 0, sampling_metadata.selected_token_indices)) - if not get_pp_group().is_last_rank: - return hidden_states # Compute the logits. with self.profiler.record_event( @@ -2725,8 +2907,6 @@ def try_revert_dummy_output_tokens(): if use_delayed_sampling: fake_output = self._delayed_sampler_outputs(model_input) - elif model_input.async_callback is not None: - model_input.async_callback() with self.profiler.record_event( 'internal', ('sample_' @@ -2748,8 +2928,7 @@ def try_revert_dummy_output_tokens(): self.cached_step_outputs.append(output) self.cached_step_inputs.append(model_input) htorch.core.mark_step() - if use_delayed_sampling \ - and model_input.async_callback is not None: + if model_input.async_callback is not None: model_input.async_callback() if i < num_steps - 1: if i == 0: @@ -2850,10 +3029,18 @@ def try_revert_dummy_output_tokens(): else: return [] + #from habana_frameworks.torch.hpu.metrics import metric_global + #gc_metric = metric_global("graph_compilation") + #print(" -------- graph_compilation: ", gc_metric.stats()) + return [output] if self.is_driver_worker else [] else: return [] + #from habana_frameworks.torch.hpu.metrics import metric_global + #gc_metric = metric_global("graph_compilation") + #print(" -------- graph_compilation: ", gc_metric.stats()) + return output if type(output) is list else [output] def _delayed_sampler_outputs(self, model_input): From 4cdc7d74dc4f4560be5ead855271e3f3bed46d0f Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Mon, 14 Apr 2025 20:37:43 +0000 Subject: [PATCH 03/38] fea(): Added the changes needed from hpu-extension #61 --- vllm/worker/hpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 57550c9199e5..cf7ff1e5d781 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -22,7 +22,7 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.bucketing.linear import HPUBucketingContext from vllm_hpu_extension.flags import enabled_flags from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm_hpu_extension.ops import batch2block, block2batch @@ -772,7 +772,7 @@ def __init__( self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, 
self.max_num_prefill_seqs, self.block_size, - self.max_num_batched_tokens) + self.max_num_batched_tokens, False) self.graphed_buckets: Set[Any] = set() self.multimodal_buckets = [] #This should be use HPUBucketingContext self.graphed_multimodal_buckets: Set[Any] = set() From 254ca6b43fa81a70a38f77e4ac9af69c6ca418ba Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Mon, 14 Apr 2025 21:39:05 +0000 Subject: [PATCH 04/38] reverted the hup_model_runner to habana_main and added the qwen2.5-vl changes --- vllm/worker/hpu_model_runner.py | 556 ++++++++++++++++++-------------- 1 file changed, 311 insertions(+), 245 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index cf7ff1e5d781..db32b1da2f46 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -22,10 +22,9 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing.linear import HPUBucketingContext +from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.flags import enabled_flags from vllm_hpu_extension.ops import LoraMask as LoraMask -from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) @@ -33,7 +32,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.hpu_attn import HPUAttentionImpl from vllm.config import DeviceConfig, VllmConfig -from vllm.distributed import broadcast_tensor_dict +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed.parallel_state import get_world_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -83,6 +82,8 @@ VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true' +VLLM_MERGED_PREFILL = os.environ.get('VLLM_MERGED_PREFILL', + 'false').lower() == 'true' DUMMY_TOKEN_ID = -1 @@ -156,6 +157,16 @@ def flatten(in_list): return list(itertools.chain(*in_list)) +def make_cpu_tensor(data, max_len, pad, dtype, flat) -> torch.Tensor: + if flat: + data = [flatten(data)] + return make_tensor_with_pad(data, + max_len=max_len, + pad=pad, + dtype=dtype, + device='cpu') + + def get_target_layer_suffix_list(model_type) -> list[str]: # This sets the suffix for the hidden layer name, which is controlled by # VLLM_CONFIG_HIDDEN_LAYERS. 
The default suffix is "DecoderLayer," which is @@ -229,9 +240,10 @@ def find_rope_layer(parent, path): return path_to_rope -class HpuModelAdapter: +class HpuModelAdapter(torch.nn.Module): def __init__(self, model, vllm_config, layer_names): + super().__init__() self.model = model self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags() self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', @@ -240,23 +252,11 @@ def __init__(self, model, vllm_config, layer_names): self.block_size = vllm_config.cache_config.block_size self.dtype = vllm_config.model_config.dtype self.layer_names = layer_names - enforce_eager = vllm_config.model_config.enforce_eager self.is_pooler = hasattr(self.model, "_pooler") self.is_causal = True if self.is_pooler: self.set_causal_option(self.model) - if not is_fake_hpu() and not htorch.utils.internal.is_lazy( - ) and not enforce_eager: - if os.getenv('VLLM_REGIONAL_COMPILATION', - 'true').lower() == 'true': - self.regional_compilation_layers_list = [ - RMSNorm, VocabParallelEmbedding - ] - self._regional_compilation(self.model) - else: - self.model = torch.compile(self.model, - backend='hpu_backend', - dynamic=False) + self.use_merged_prefill = VLLM_MERGED_PREFILL model_config = getattr(self.model, "config", None) self.model_is_mrope = uses_mrope(model_config) @@ -275,26 +275,6 @@ def __init__(self, model, vllm_config, layer_names): self.model.language_model.model = htorch.hpu.wrap_in_hpu_graph( self.model.language_model.model, disable_tensor_cache=True) - def _regional_compilation(self, - module, - parent_module=None, - module_name=None): - if isinstance(module, torch.nn.ModuleList): - for children_name, children_module in module.named_children(): - self._compile_region(module, children_name, children_module) - elif any( - isinstance(module, layer) - for layer in self.regional_compilation_layers_list): - self._compile_region(parent_module, module_name, module) - else: - for children_name, children_module in module.named_children(): - self._regional_compilation(children_module, module, - children_name) - - def _compile_region(self, model, name, module): - module = torch.compile(module, backend='hpu_backend', dynamic=False) - setattr(model, name, module) - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): if (attn_metadata is None @@ -303,6 +283,9 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, or not attn_metadata.is_prompt): return attn_metadata + if attn_metadata.attn_bias is not None: + return attn_metadata + prefill_metadata = attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor @@ -379,20 +362,10 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): attn_bias=attn_bias) return metadata - def _set_block_scales(self, metadata, device): - block_mapping = metadata.block_mapping - ones = torch.ones((block_mapping.size(0), ), - device=device, - dtype=block_mapping.dtype) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) - metadata = metadata._replace(block_scales=block_scales) - return metadata - def _set_indices_and_offsets(self, metadata, block_size, is_prompt): slot_mapping = metadata.slot_mapping.flatten() indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - if is_prompt: + if is_prompt and not self.use_merged_prefill: indices = indices.unflatten(0, (-1, block_size))[:, 0] offsets = None else: @@ -410,7 +383,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device, else: 
attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) - attn_metadata = self._set_block_scales(attn_metadata, device) attn_metadata = self._set_indices_and_offsets(attn_metadata, self.block_size, attn_metadata.is_prompt) @@ -507,8 +479,9 @@ def forward(self, *args, **kwargs): with set_forward_context(kwargs['attn_metadata'], self.vllm_config, virtual_engine): - #breakpoint() hidden_states = self.model(*args, **kwargs) + if not get_pp_group().is_last_rank: + return hidden_states hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if selected_token_indices is not None: hidden_states = hidden_states.index_select( @@ -521,6 +494,9 @@ def compute_logits(self, *args, **kwargs): def sample(self, *args, **kwargs): return self.model.sample(*args, **kwargs) + def make_empty_intermediate_tensors(self, *args, **kwargs): + return self.model.make_empty_intermediate_tensors(*args, **kwargs) + def generate_proposals(self, *args, **kwargs): return self.model.generate_proposals(*args, **kwargs) @@ -735,6 +711,10 @@ def __init__( self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens self.block_size = self.cache_config.block_size + self.use_merged_prefill = VLLM_MERGED_PREFILL + assert not (self.scheduler_config.use_padding_aware_scheduling + and self.use_merged_prefill), \ + 'Merged prefill is not compatible with padding aware scheduling!' self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = self.cache_config.cache_dtype @@ -769,12 +749,15 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None + HPUBucketingContext = get_bucketing_context() self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, - self.max_num_batched_tokens, False) + self.max_num_batched_tokens, + self.use_merged_prefill, + self.max_model_len) self.graphed_buckets: Set[Any] = set() - self.multimodal_buckets = [] #This should be use HPUBucketingContext + self.multimodal_buckets = [] #This should be use HPUBucketingContext self.graphed_multimodal_buckets: Set[Any] = set() self._set_gc_threshold() @@ -925,6 +908,12 @@ def load_model(self) -> None: layer_names=path_to_rope) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) + with HabanaMemoryProfiler() as m_wrap: + self._maybe_compile(self.model, + vllm_config=self.vllm_config, + layer_names=path_to_rope) + msg = f"Compiling took {m_wrap.get_summary_string()}" + logger.info(msg) self.model_memory_usage = m.consumed_device_memory msg = f"Loading model weights took in total {m.get_summary_string()}" @@ -962,6 +951,65 @@ def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): else: return HpuModelAdapter(*args, **kwargs) + def _maybe_compile(self, *args, **kwargs): + if not is_fake_hpu() and not htorch.utils.internal.is_lazy( + ) and not self.vllm_config.model_config.enforce_eager: + fullgraph = os.getenv('VLLM_T_COMPILE_FULLGRAPH', + 'false').strip().lower() in ("1", "true") + if os.getenv('VLLM_REGIONAL_COMPILATION', + 'true').strip().lower() in ("1", "true"): + compiled_methods = [self.model._set_block_mapping] + for method in compiled_methods: + method = torch.compile(method, + backend='hpu_backend', + fullgraph=fullgraph, + dynamic=False) + self.regional_compilation_layers_list = [ + RMSNorm, VocabParallelEmbedding + ] + self._regional_compilation(self.model, fullgraph) + else: + self.model = torch.compile(self.model, + backend='hpu_backend', + 
fullgraph=fullgraph, + dynamic=False) + + def _regional_compilation(self, + module, + fullgraph, + parent_module=None, + module_name=None): + if isinstance(module, torch.nn.ModuleList): + for children_name, children_module in module.named_children(): + self._compile_region(module, fullgraph, children_name, + children_module) + elif any( + isinstance(module, layer) + for layer in self.regional_compilation_layers_list): + self._compile_region( + parent_module, + fullgraph, + module_name, + module, + ) + else: + for children_name, children_module in module.named_children(): + self._regional_compilation(children_module, fullgraph, module, + children_name) + + def _compile_region( + self, + model, + fullgraph, + name, + module, + ): + module = torch.compile(module, + backend='hpu_backend', + fullgraph=fullgraph, + dynamic=False) + setattr(model, name, module) + def get_model(self) -> torch.nn.Module: if isinstance(self.model, HpuModelAdapter): return self.model.model @@ -1043,6 +1091,34 @@ def _get_mrope_positions_and_delta(self, seq_data, mm_kwargs, context_len): assert mrope_positions is not None return mrope_positions, mrope_position_delta + def make_attn_bias(self, seq_lens, max_prompt_len, dtype): + seq_pos = [list(range(sl)) for sl in seq_lens] + seq_idx = [[i] * sl for i, sl in enumerate(seq_lens)] + seq_pos_t = make_cpu_tensor(seq_pos, + max_len=max_prompt_len, + pad=-1, + dtype=torch.long, + flat=self.use_merged_prefill) + seq_idx_t = make_cpu_tensor(seq_idx, + max_len=max_prompt_len, + pad=-1, + dtype=torch.long, + flat=self.use_merged_prefill) + q_seq_idx_t = seq_idx_t.unsqueeze(-1) + kv_seq_idx_t = seq_idx_t.unsqueeze(-2) + q_seq_pos_t = seq_pos_t.unsqueeze(-1) + kv_seq_pos_t = seq_pos_t.unsqueeze(-2) + seq_idx_t = q_seq_idx_t != kv_seq_idx_t + seq_pos_t = kv_seq_pos_t > q_seq_pos_t + attn_mask = seq_idx_t | seq_pos_t + attn_bias = torch.zeros_like(attn_mask, dtype=dtype) + attn_bias.masked_fill_(attn_mask, -math.inf) + return attn_bias.unsqueeze(1) + + def move_to_device(self, tensor): + return tensor if tensor is None else tensor.to(self.device, + non_blocking=True) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1096,6 +1172,10 @@ def _prepare_prompt( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size + if context_len == seq_len \ + and self.vllm_config.cache_config.enable_prefix_caching: + # Fully cached prompt - compute only last token + context_len = context_len - 1 prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: @@ -1135,9 +1215,6 @@ def _prepare_prompt( seq_group_metadata.mm_processor_kwargs, ) - # padding image patches (pixel_values, image_grid_thw) - #mm_kwargs = pad_multimodal_data(mm_kwargs) - # special processing for mrope position deltas. 
if self.model_is_mrope: mrope_positions, mrope_position_delta = \ @@ -1194,13 +1271,14 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) - max_query_len = max(query_lens) + if self.use_merged_prefill: + target_query_len = sum(query_lens) + else: + target_query_len = max(query_lens) real_num_seqs = len(query_lens) - assert max_query_len > 0 - max_prompt_len = max( - self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), + self.bucketing_ctx.get_padded_prompt_seq_len(target_query_len), self.block_size) lora_ids: List[int] = [] @@ -1240,34 +1318,40 @@ def _prepare_prompt( else: prefix_block_list_tensor = None - input_tokens_tensor = make_tensor_with_pad(input_tokens, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device='cpu') - + input_tokens_tensor = make_cpu_tensor(input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + flat=self.use_merged_prefill) if self.model_is_mrope: input_positions = \ make_mrope_positions_tensor_with_pad(input_positions=input_positions, - input_mrope_positions=input_mrope_positions, - max_prompt_len=max_prompt_len, - pad=0) + input_mrope_positions=input_mrope_positions, + max_prompt_len=max_prompt_len, + pad=0) else: - input_positions = make_tensor_with_pad(input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device='cpu') + input_positions = make_cpu_tensor(input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + flat=self.use_merged_prefill) + + slot_mapping = make_cpu_tensor(slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + flat=self.use_merged_prefill) - slot_mapping = make_tensor_with_pad(slot_mapping, - max_len=max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device='cpu') + attn_bias = None + seq_lens_tensor = None + context_lens_tensor = None + + if self.use_merged_prefill: + attn_bias = self.make_attn_bias(seq_lens, max_prompt_len, + self.model_config.dtype) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu') - context_lens_tensor = torch.tensor(context_lens, dtype=torch.long, device='cpu') @@ -1281,18 +1365,15 @@ def _prepare_prompt( # Note: num_prefill_tokens is calculated using the length of # input_tokens after padding. 
num_prefill_tokens = input_tokens_tensor.numel() - if prefix_block_list_tensor is not None: - prefix_block_list_tensor = prefix_block_list_tensor.to( - self.device, non_blocking=True) - input_tokens_tensor = input_tokens_tensor.to( # type: ignore - self.device, non_blocking=True) - input_positions = input_positions.to( # type: ignore - self.device, non_blocking=True) - slot_mapping = slot_mapping.to( # type: ignore - self.device, non_blocking=True) - seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) - context_lens_tensor = context_lens_tensor.to(self.device, - non_blocking=True) + + prefix_block_list_tensor = self.move_to_device( + prefix_block_list_tensor) + input_tokens_tensor = self.move_to_device(input_tokens_tensor) + input_positions = self.move_to_device(input_positions) + seq_lens_tensor = self.move_to_device(seq_lens_tensor) + slot_mapping = self.move_to_device(slot_mapping) + context_lens_tensor = self.move_to_device(context_lens_tensor) + attn_bias = self.move_to_device(attn_bias) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -1301,11 +1382,10 @@ def _prepare_prompt( block_usage=None, block_indices=None, block_offsets=None, - block_scales=None, block_groups=None, - attn_bias=None, + attn_bias=attn_bias, seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, + seq_lens_tensor=self.move_to_device(seq_lens_tensor), context_lens_tensor=context_lens_tensor, num_prefills=real_num_seqs, num_prefill_tokens=num_prefill_tokens, @@ -1585,7 +1665,6 @@ def _prepare_decode( block_usage=block_usage, block_indices=None, block_offsets=None, - block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1614,6 +1693,7 @@ def _prepare_decode( def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None ) -> Tuple[TModelInputForHPU, SamplingMetadata]: if len(seq_group_metadata_list) == 0: return self._model_input_cls(), None @@ -1671,9 +1751,14 @@ def prepare_input_tensors( ) = self._prepare_decode(decode_reqs) if not self.is_pooler: + generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) + seq_group_metadata_list, + seq_lens, + query_lens, + self.device, + self.pin_memory, + generators=generators) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -1697,29 +1782,32 @@ def prepare_input_tensors( lora_requests = decode_lora_requests lora_ids = decode_lora_ids - # FIXME: We need to adjust selected_token_indices to accommodate - # for padding - max_len = input_tokens.size(1) - paddings = [max_len - q for q in query_lens] - paddings = [0] + paddings[:-1] - paddings = list(itertools.accumulate(paddings)) - paddings_prompt_logprobs = [] - - if not self.is_pooler: - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.sampling_params \ - and seq_group_metadata.sampling_params.prompt_logprobs \ - is not None and seq_group_metadata.is_prompt: - paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) - - paddings = torch.tensor( - paddings_prompt_logprobs - if paddings_prompt_logprobs else paddings, - dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device) - sampling_metadata.selected_token_indices.add_(paddings) - else: + if self.is_pooler: sampling_metadata = None + elif not 
self.use_merged_prefill: + # FIXME: We need to adjust selected_token_indices to accommodate + # for padding + max_len = input_tokens.size(1) + paddings = [max_len - q for q in query_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + + if not self.is_pooler: + for i, seq_group_metadata in enumerate( + seq_group_metadata_list): + if seq_group_metadata.sampling_params \ + and seq_group_metadata.sampling_params.prompt_logprobs \ + is not None and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * + seq_lens[i]) + + paddings = torch.tensor( + paddings_prompt_logprobs + if paddings_prompt_logprobs else paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) if self.lora_config: lora_mapping = LoRAMapping( @@ -1789,6 +1877,76 @@ def prepare_input_tensors( lora_ids=lora_ids), \ sampling_metadata + def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], + is_prompt: bool): + ''' + This is a helper function to create the mask for lora computations. + Lora Mask is needed to ensure we match the correct lora weights for the + for the request. + For Prompt phase we have + lora_mask with shape (batch_size * seq_len, max_loras * max_rank) + lora_logits_mask with shape (batch_size, max_loras * max_rank) + For Decode phase we have both + lora_mask and lora_logits_mask with shape + (batch_size, max_loras * max_rank) + ''' + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + lora_index = 0 + + if self.lora_config: + if is_prompt: + lora_mask = torch.zeros( + input_tokens.shape[0] * input_tokens.shape[1], + (self.lora_config.max_loras) *\ + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros( + input_tokens.shape[0], (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(input_tokens.shape[1], + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_row = i * input_tokens.shape[1] + end_row = start_row + input_tokens.shape[1] + start_col = lora_index * self.lora_config.max_lora_rank + end_col = start_col + self.lora_config.max_lora_rank + lora_mask[start_row:end_row, start_col:end_col] = ones + lora_logits_mask[i, start_col:end_col] = logit_ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + else: + lora_mask = torch.zeros(input_tokens.shape[0], + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_pos = lora_index * self.lora_config.max_lora_rank + end_pos = start_pos + self.lora_config.max_lora_rank + lora_mask[i, start_pos:end_pos] = ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_mask + + return lora_mask, lora_logits_mask + def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: 
return attn_metadata.slot_mapping.size(1) @@ -1827,7 +1985,6 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'is_prompt', 'block_indices', 'block_offsets', - 'block_scales', 'block_groups', ]) return attention_metadata @@ -1896,7 +2053,6 @@ def create_dummy_seq_group_metadata(self, sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) seq_len = max(seq_len, 1) - if is_prompt: input_len = seq_len output_len = 0 @@ -1963,7 +2119,6 @@ def warmup_scenario(self, # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests: List[LoRARequest] = [] dummy_lora_requests_per_seq: List[LoRARequest] = [] if self.lora_config and is_lora_profile_run: @@ -2034,7 +2189,18 @@ def warmup_scenario(self, is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: - self.execute_model(inputs, kv_caches, warmup_mode=True) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = \ + self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + context_size=seq_len if is_prompt else 1, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(inputs, + kv_caches, + intermediate_tensors=intermediate_tensors, + warmup_mode=True) else: # decode with multi-step inputs = dataclasses.replace(inputs, is_first_multi_step=True, @@ -2047,7 +2213,6 @@ def warmup_scenario(self, inputs = dataclasses.replace(inputs, is_first_multi_step=False, is_last_step=True) - # TODO: why 2 execute_model? self.execute_model(inputs, kv_caches, warmup_mode=True, @@ -2215,7 +2380,7 @@ def warmup_graphs(self, if batch_size not in warmed_random_sampler_bs else 0, multimodal_seqs_group_metada=True if is_prompt else False, height=112, - width=112) # everythign must be 112 aligned (for now) + width=112) warmed_random_sampler_bs.add(batch_size) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2250,7 +2415,6 @@ def warmup_graphs(self, return_time=True) self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('graph', t)] - return total_mem, total_batch_seq, captured_all def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @@ -2268,6 +2432,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + if not self.is_pooler: + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_decode_buckets(max_blocks) + + if supports_multimodal(self.model.model): + FIXED_MULTIMODAL_BUCKETS = self.model.model.FIXED_MULTIMODAL_BUCKETS + self.multimodal_buckets = [[112, total_size * 14 * 14 / 112] for total_size in FIXED_MULTIMODAL_BUCKETS] + print("Multimodal bucket :", self.multimodal_buckets) + if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' @@ -2277,35 +2450,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, True) raise AssertionError("Finished profiling") - if not self.is_pooler: - max_blocks = kv_caches[0][0].size(0) - self.bucketing_ctx.generate_prompt_buckets() - - if supports_multimodal(self.model.model): - if True: - 
FIXED_MULTIMODAL_BUCKETS = self.model.model.FIXED_MULTIMODAL_BUCKETS - # [1600, 3200, 4800, 6400, 9600] - # 1600 means an image with 1600*14*14 ppixels, ie 560x560 - self.multimodal_buckets = [[112, total_size * 14 * 14 / 112] for total_size in FIXED_MULTIMODAL_BUCKETS] - # TODO This is qwen2.5vl/model specific code here. This should come from model file? - else: - #TODO: - # Multimodal buckets are based on H,W , it should be changed to be aligned with multimodal paddings. - # Also need to move to HPUBucketingContext. - #Multimodal bucket : [[560, 560], [560, 1120], [560, 1680], [1120, 560], [1120, 1120], [1120, 1680], [1680, 560], [1680, 1120], [1680, 1680]] - VLLM_MULTIMODAL_BUCKET = 560 #Pick number divisible by 28(patchsize*mergesize), this can be env. - max_seq_len = 1120 #2048 #self.max_num_batched_tokens - bucket = VLLM_MULTIMODAL_BUCKET - # TODO this number and self.FIXED_MULTIMODAL_BUCKETS should be in sync - self.multimodal_buckets = [ - [h, w] for h in range(bucket, max_seq_len + 1, bucket) - for w in range(bucket, max_seq_len + 1, bucket) - ] - breakpoint() - print("Multimodal bucket :", self.multimodal_buckets) - - if not self.is_pooler: - self.bucketing_ctx.generate_decode_buckets(max_blocks) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION', 'true').lower() == 'true' else 1 @@ -2339,7 +2483,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' 'Warmup time will be negatively impacted. ' 'Please update Gaudi Software Suite.') - with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, @@ -2347,7 +2490,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if not self.is_pooler: self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, kv_caches) - #breakpoint() # self.visual_warmup_times is populated at this point (but without hpu graphs) + if not self.enforce_eager and htorch.utils.internal.is_lazy(): if not self.is_pooler: assert self.mem_margin is not None, \ @@ -2453,15 +2596,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: f"Warmup finished in {elapsed_time:.0f} secs, " f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) - #breakpoint() # self.visual_warmup_times is populated at this point (with and without hpu graphs) - # Another way to do this is for qwen2.5vl model (or any multimodal model) to track if a new shape is incoming, and then enable a timer. - # then this "time collection" logic is hidden in model file itself, and model_runner isnt tainted with it - # Also inside the model file, we may get a better estimate of the time. right now the time is a proxy as it also contains "text time" (though all text inp is of same len (2048)?) 
- # bt we'd need to markstep/sync=True if we are collecting times inside if hasattr(self, 'visual_warmup_times'): summary = {k: min([t for _, t in self.visual_warmup_times[k]]) for k in self.visual_warmup_times} self.visual_warmup_times = summary - self.model.visual_warmup_times = self.visual_warmup_times # a strange way to pass this in, but gonna charge ahead for now with this + self.model.visual_warmup_times = self.visual_warmup_times self.profiler.end() def finish_measurements(self): @@ -2620,7 +2758,7 @@ def prepare_model_input( self.profiler_counter_helper.capture_seq_group_metadata_stats( seq_group_metadata_list=seq_group_metadata_list) model_input, sampling_metadata = self.prepare_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list, finished_requests_ids) assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt @@ -2629,76 +2767,6 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) - def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], - is_prompt: bool): - ''' - This is a helper function to create the mask for lora computations. - Lora Mask is needed to ensure we match the correct lora weights for the - for the request. - For Prompt phase we have - lora_mask with shape (batch_size * seq_len, max_loras * max_rank) - lora_logits_mask with shape (batch_size, max_loras * max_rank) - For Decode phase we have both - lora_mask and lora_logits_mask with shape - (batch_size, max_loras * max_rank) - ''' - lora_mask: torch.Tensor = None - lora_logits_mask: torch.Tensor = None - lora_index = 0 - - if self.lora_config: - if is_prompt: - lora_mask = torch.zeros( - input_tokens.shape[0] * input_tokens.shape[1], - (self.lora_config.max_loras) *\ - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - lora_logits_mask = torch.zeros( - input_tokens.shape[0], (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - ones = torch.ones(input_tokens.shape[1], - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - logit_ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_row = i * input_tokens.shape[1] - end_row = start_row + input_tokens.shape[1] - start_col = lora_index * self.lora_config.max_lora_rank - end_col = start_col + self.lora_config.max_lora_rank - lora_mask[start_row:end_row, start_col:end_col] = ones - lora_logits_mask[i, start_col:end_col] = logit_ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_logits_mask.to('hpu') - else: - lora_mask = torch.zeros(input_tokens.shape[0], - (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_pos = lora_index * self.lora_config.max_lora_rank - end_pos = start_pos + self.lora_config.max_lora_rank - lora_mask[i, start_pos:end_pos] = ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_mask - - return lora_mask, lora_logits_mask - def _get_seq_ids(self, model_input): return ([ sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups @@ 
-2728,6 +2796,9 @@ def execute_model( use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode assert not (use_delayed_sampling and num_steps != 1), \ 'Delayed sampling is not compatible with MSS!' + assert not (use_delayed_sampling and + self.parallel_config.pipeline_parallel_size != 1), \ + 'Delayed sampling is not compatible with Pipeline Parallelism!' assert model_input.input_tokens is not None if use_delayed_sampling and not model_input.is_prompt and \ self.is_driver_worker: @@ -2809,7 +2880,6 @@ def execute_model( lora_mask, lora_logits_mask = self.create_lora_mask( input_tokens, model_input.lora_ids, attn_metadata.is_prompt) - execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -2877,7 +2947,6 @@ def try_revert_dummy_output_tokens(): with self.profiler.record_event('internal', model_event_name, args=profiler_args): - hidden_states = self.model.forward( **execute_model_kwargs, selected_token_indices=sampling_metadata. @@ -2887,6 +2956,8 @@ def try_revert_dummy_output_tokens(): LoraMask.setLoraMask( lora_logits_mask.index_select( 0, sampling_metadata.selected_token_indices)) + if not get_pp_group().is_last_rank: + return hidden_states # Compute the logits. with self.profiler.record_event( @@ -2907,6 +2978,8 @@ def try_revert_dummy_output_tokens(): if use_delayed_sampling: fake_output = self._delayed_sampler_outputs(model_input) + elif model_input.async_callback is not None: + model_input.async_callback() with self.profiler.record_event( 'internal', ('sample_' @@ -2928,7 +3001,8 @@ def try_revert_dummy_output_tokens(): self.cached_step_outputs.append(output) self.cached_step_inputs.append(model_input) htorch.core.mark_step() - if model_input.async_callback is not None: + if use_delayed_sampling \ + and model_input.async_callback is not None: model_input.async_callback() if i < num_steps - 1: if i == 0: @@ -3029,18 +3103,10 @@ def try_revert_dummy_output_tokens(): else: return [] - #from habana_frameworks.torch.hpu.metrics import metric_global - #gc_metric = metric_global("graph_compilation") - #print(" -------- graph_compilation: ", gc_metric.stats()) - return [output] if self.is_driver_worker else [] else: return [] - #from habana_frameworks.torch.hpu.metrics import metric_global - #gc_metric = metric_global("graph_compilation") - #print(" -------- graph_compilation: ", gc_metric.stats()) - return output if type(output) is list else [output] def _delayed_sampler_outputs(self, model_input): From e8d4c3e79713644fc8c4e0e4f6c88d81c4f8b08a Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Mon, 14 Apr 2025 22:38:51 +0000 Subject: [PATCH 05/38] using max_pixels instead of h,w --- vllm/worker/hpu_model_runner.py | 55 ++++++++++++++------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index db32b1da2f46..881704aa4017 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1025,7 +1025,6 @@ def _use_graphs(self, batch_size, seq_len, is_prompt, is_multimodal): else: #TODO:For now return TRUE for development #This needs to be updated later with proper bucket detections. 
- #return (batch_size, height, weight) in self.graphed_multimodal_buckets return True def _is_valid_bucket(self, bucket): @@ -1991,7 +1990,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, lora_request, temperature, - height, width): + max_pixels): from vllm.multimodal.utils import cached_get_tokenizer @@ -2009,15 +2008,15 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, tokenizer) mm_counts = self.mm_registry.get_mm_limits_per_prompt( self.model_config) - #mm_counts = {"image":1} - print("mm_counts:", mm_counts) factory = processor.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( seq_len=seq_len, mm_counts=mm_counts, - image_width=width, - image_height=height) - + ) + if max_pixels is not None: + # Note: We will overwrite this value if any exits + processor_inputs.hf_processor_mm_kwargs["max_pixels"] = max_pixels + print(" ============ ", processor_inputs.hf_processor_mm_kwargs) mm_inputs = processor.apply( prompt=processor_inputs.prompt_text, mm_data=processor_inputs.mm_data, @@ -2030,6 +2029,7 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) seq_data = SequenceData.from_seqs(prompt_token_ids) + import pdb; pdb.set_trace() return SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -2105,8 +2105,7 @@ def warmup_scenario(self, is_lora_profile_run=False, multimodal_seqs_group_metada=False, temperature=0, - height=None, - width=None, + max_pixels=None, return_time=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, multimodal_seqs_group_metada) @@ -2151,8 +2150,8 @@ def warmup_scenario(self, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None, temperature=temperature, - height=height, - width=width) for i in range(batch_size) + max_pixels=max_pixels, + ) for i in range(batch_size) ] elif is_prompt: seqs = [ @@ -2273,7 +2272,7 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): logger.info(msg) def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, - height, width): + max_pixels): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) dim = "num_blocks" @@ -2281,7 +2280,7 @@ def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " - f"{dim}:{seq_len}", f"hw:({height},{width})", + f"{dim}:{seq_len}", f"max_pixels:{max_pixels}", f"free_mem:{free_mem}") logger.info(msg) @@ -2295,12 +2294,11 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): # Warmup Multimodal with fixed seq_len if not hasattr(self, 'visual_warmup_times'): self.visual_warmup_times = {} - for i, (h, w) in enumerate(self.multimodal_buckets): + for i, max_pixels in enumerate(self.multimodal_buckets): max_batch_size = 1 #TODO: For now we hardcoded batch 1. 
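+            # Each bucket value is forwarded to the HF processor as
+            # max_pixels, so the dummy image built for this warmup step is
+            # resized according to that cap before reaching the vision encoder.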
max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX self.log_warmup_multimodal('Image', i, max_seq_len, max_batch_size, - max_seq_len, h, w) - assert h%112 == 0 and w % 112 == 0, "Expected to be 112 aligned for now" + max_seq_len, max_pixels) t = self.warmup_scenario(batch_size=max_batch_size, seq_len=max_seq_len, is_prompt=True, @@ -2308,8 +2306,7 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): is_pt_profiler_run=False, is_lora_profile_run=True, multimodal_seqs_group_metada=True, - height=h, - width=w, + max_pixels=max_pixels, return_time=True) #if ((h*w) / (14*14)) in self.visual_warmup_times: #breakpoint() @@ -2329,8 +2326,7 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): is_prompt, kv_caches, multimodal_seqs_group_metada=True, - height=112, - width=112) # everythign must be 112 aligned (for now) + max_pixels=max_pixels) # everythign must be 112 aligned (for now) else: self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) @@ -2379,8 +2375,7 @@ def warmup_graphs(self, temperature=1.0 if batch_size not in warmed_random_sampler_bs else 0, multimodal_seqs_group_metada=True if is_prompt else False, - height=112, - width=112) + max_pixels=14*14*112) warmed_random_sampler_bs.add(batch_size) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2410,8 +2405,7 @@ def warmup_graphs(self, #is_pt_profiler_run=False, #is_lora_profile_run=True, multimodal_seqs_group_metada=True, - height=h, - width=w, + max_pixels=max_pixels, return_time=True) self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('graph', t)] @@ -2435,10 +2429,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if not self.is_pooler: max_blocks = kv_caches[0][0].size(0) self.bucketing_ctx.generate_decode_buckets(max_blocks) - + if supports_multimodal(self.model.model): - FIXED_MULTIMODAL_BUCKETS = self.model.model.FIXED_MULTIMODAL_BUCKETS - self.multimodal_buckets = [[112, total_size * 14 * 14 / 112] for total_size in FIXED_MULTIMODAL_BUCKETS] + self.multimodal_buckets = self.model.model.FIXED_MULTIMODAL_BUCKETS print("Multimodal bucket :", self.multimodal_buckets) if profile := os.environ.get('VLLM_PT_PROFILE', None): @@ -2596,10 +2589,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: f"Warmup finished in {elapsed_time:.0f} secs, " f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) - if hasattr(self, 'visual_warmup_times'): - summary = {k: min([t for _, t in self.visual_warmup_times[k]]) for k in self.visual_warmup_times} - self.visual_warmup_times = summary - self.model.visual_warmup_times = self.visual_warmup_times + # if hasattr(self, 'visual_warmup_times'): + # summary = {k: min([t for _, t in self.visual_warmup_times[k]]) for k in self.visual_warmup_times} + # self.visual_warmup_times = summary + # self.model.visual_warmup_times = self.visual_warmup_times self.profiler.end() def finish_measurements(self): From 86e65fbd716457c1c949e120f8029963ee34af11 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 15 Apr 2025 00:50:13 +0000 Subject: [PATCH 06/38] clean up if/else --- vllm/model_executor/models/qwen2_5_vl.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ae43a5d77628..c21f907c5933 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ 
-1061,7 +1061,7 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: - if True: + if is_hpu: ''' go thru grid_thw say grid_thw is 1,16,16 and 1,128,128 @@ -1108,20 +1108,8 @@ def _process_image_input( results_cat = torch.concat(results) image_embeds = results_cat else: - pixel_values, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( - pixel_values, grid_thw) - assert pixel_values.shape[0] % 64 == 0, f"We need image h/w to be aligned to 112 for now. Which will make pixel_values be a multiple of (112/14)*(112/14)=64 (14 is patch size for ViT). Got pixel_values shape {pixel_values.shape[0]}" - #print('.......', cu_seqlens, expand_to_max(cu_seqlens, 10)) - expanded_cu_seqlens = expand_to_max(cu_seqlens, 10) - htcore.mark_step() # padding in expand_to_max is dynamic - #breakpoint() - hidden_states = self.visual(pixel_values, - rotary_pos_emb=rot_pos_emb, - cu_seqlens=expanded_cu_seqlens,) - #cu_window_seqlens=cu_window_seqlens) - htcore.mark_step() - image_embeds = self.visual.post_attn(hidden_states, window_index) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size From 4037e043ecb089f174f3a1ffd11761d9557ca966 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 15 Apr 2025 01:28:46 +0000 Subject: [PATCH 07/38] clean up if-else 2 --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c21f907c5933..2600e0af5e39 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1060,6 +1060,7 @@ def _process_image_input( if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) if is_hpu: ''' @@ -1108,7 +1109,6 @@ def _process_image_input( results_cat = torch.concat(results) image_embeds = results_cat else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. 
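The HPU branch above walks grid_thw one image at a time so each vision-encoder
forward sees a single image, and the warmup code elsewhere in this series
prepares the encoder for a fixed set of patch-count buckets ([1600, 3200, 4800,
6400, 9600], where 1600 patches corresponds to a 560x560 image at the 14-pixel
ViT patch size). A minimal sketch of how one image's patch sequence could be
snapped to such a bucket, assuming only those quoted values; pick_bucket and
pad_patches are illustrative helpers, not code from this patch:

    import torch

    # Bucket values quoted in the warmup hunks earlier in this series.
    FIXED_MULTIMODAL_BUCKETS = [1600, 3200, 4800, 6400, 9600]  # ViT patches per image

    def pick_bucket(num_patches: int) -> int:
        """Smallest fixed bucket that can hold this image's patch count."""
        for bucket in FIXED_MULTIMODAL_BUCKETS:
            if num_patches <= bucket:
                return bucket
        return FIXED_MULTIMODAL_BUCKETS[-1]

    def pad_patches(pixel_values: torch.Tensor) -> torch.Tensor:
        """Zero-pad a flattened (num_patches, patch_dim) patch sequence up to
        its bucket so the vision encoder only ever sees a few static shapes."""
        num_patches, patch_dim = pixel_values.shape
        bucket = pick_bucket(num_patches)
        padded = pixel_values.new_zeros((bucket, patch_dim))
        padded[:num_patches] = pixel_values
        return padded

    # A 560x560 image gives (560 // 14) ** 2 = 1600 patches and maps to bucket 1600.

In the hunk above the per-image embeddings are then concatenated back together
via torch.concat(results).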
From de4e2c900f55701f50d813d1026755caae379d54 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 15 Apr 2025 01:33:23 +0000 Subject: [PATCH 08/38] Fix cu_seqlens_now --- vllm/model_executor/models/qwen2_5_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 2600e0af5e39..ced9e4eda859 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -280,6 +280,7 @@ def __init__( def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] + breakpoint() seq_len, bs, _ = qkv.shape if self.tp_size > 1: qkv = tensor_model_parallel_all_gather(qkv) @@ -727,10 +728,9 @@ def forward( hidden_states = hidden_states.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: - #fullatt_block_attn_mask = None cu_seqlens_now = cu_seqlens + else: cu_seqlens_now = None - cu_seqlens_now = cu_window_seqlens hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) From 35df595e0995cf60350f1a0ece9bc499bd528a7c Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 15 Apr 2025 03:32:32 +0000 Subject: [PATCH 09/38] Remove pdb, fix shape --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/worker/hpu_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ced9e4eda859..70b4a329c68a 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -725,12 +725,12 @@ def forward( cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) hidden_states = x.unsqueeze(1) - hidden_states = hidden_states.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens else: cu_seqlens_now = None + breakpoint() hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 881704aa4017..d77e2a122ffa 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2029,7 +2029,7 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) seq_data = SequenceData.from_seqs(prompt_token_ids) - import pdb; pdb.set_trace() + #import pdb; pdb.set_trace() return SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, From 37311b171b9e2f5f6679732858a329837565c480 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 15 Apr 2025 03:41:23 +0000 Subject: [PATCH 10/38] Remove breakpoints --- vllm/model_executor/models/qwen2_5_vl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 70b4a329c68a..95253fde9ddd 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -280,7 +280,6 @@ def __init__( def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] - breakpoint() seq_len, bs, _ = qkv.shape if self.tp_size > 1: qkv = tensor_model_parallel_all_gather(qkv) @@ -730,7 +729,6 @@ def forward( cu_seqlens_now = cu_seqlens else: cu_seqlens_now = None - breakpoint() hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, 
rotary_pos_emb=rotary_pos_emb) From 10dfab9983b8ddd177d2b4fe488ea7e2217766e1 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Tue, 15 Apr 2025 18:51:29 +0000 Subject: [PATCH 11/38] using max_pixels during warmup split warmup in text only and image only force input_positions in text to be 3, seq_len --- vllm/worker/hpu_model_runner.py | 155 ++++++++++++++------------------ 1 file changed, 65 insertions(+), 90 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index d77e2a122ffa..6439481f1d6d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -52,6 +52,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, @@ -77,7 +78,7 @@ # Use caution when updating them! _PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 - +_UNSET_MAX_PIXELS = 9999999 LORA_WARMUP_RANK = 8 VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', @@ -268,8 +269,13 @@ def __init__(self, model, vllm_config, layer_names): self.split_graph = self.model_is_mrope and os.getenv( 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] + if not htorch.utils.internal.is_lazy() and self.split_graph: + logger.warning( + f"[Multimodal] HPU is not in Lazy Mode, " + f"split graph has not impact" + ) if htorch.utils.internal.is_lazy() and self.split_graph: - print("Split Graph to Visual and Language") + logger.info("[Multimodal] Split Graph to Visual and Language") self.model.visual = htorch.hpu.wrap_in_hpu_graph( self.model.visual, disable_tensor_cache=False) self.model.language_model.model = htorch.hpu.wrap_in_hpu_graph( @@ -1015,17 +1021,15 @@ def get_model(self) -> torch.nn.Module: return self.model.model return self.model - def _use_graphs(self, batch_size, seq_len, is_prompt, is_multimodal): + def _use_graphs(self, batch_size, seq_len, is_prompt, max_pixels=None): if self.enforce_eager: return False if self.skip_warmup: return True - if not is_multimodal or not self.graphed_multimodal_buckets: + if not max_pixels or not self.graphed_multimodal_buckets: return (batch_size, seq_len, is_prompt) in self.graphed_buckets else: - #TODO:For now return TRUE for development - #This needs to be updated later with proper bucket detections. 
- return True + return (batch_size, seq_len, is_prompt, max_pixels) in self.graphed_buckets def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens @@ -1989,16 +1993,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: return attention_metadata def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, - lora_request, temperature, - max_pixels): - - from vllm.multimodal.utils import cached_get_tokenizer - - if self.is_pooler: - sampling_params = None - else: - sampling_params = SamplingParams(temperature=temperature) - + max_pixels, sampling_params, + lora_request): assert self.mm_registry.has_processor(self.model_config) tokenizer = cached_get_tokenizer( self.model_config.tokenizer, @@ -2013,23 +2009,21 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, seq_len=seq_len, mm_counts=mm_counts, ) - if max_pixels is not None: - # Note: We will overwrite this value if any exits - processor_inputs.hf_processor_mm_kwargs["max_pixels"] = max_pixels - print(" ============ ", processor_inputs.hf_processor_mm_kwargs) + if max_pixels and max_pixels != _UNSET_MAX_PIXELS: + hf_processor_mm_kwargs = dict(processor_inputs.hf_processor_mm_kwargs) + hf_processor_mm_kwargs["max_pixels"] = max_pixels + print(f" ===== {max_pixels} ====== : ", hf_processor_mm_kwargs) mm_inputs = processor.apply( prompt=processor_inputs.prompt_text, mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, ) prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] - prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) seq_data = SequenceData.from_seqs(prompt_token_ids) - #import pdb; pdb.set_trace() return SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -2046,6 +2040,7 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, + max_pixels=None, temperature=0): if self.is_pooler: sampling_params = None @@ -2053,6 +2048,14 @@ def create_dummy_seq_group_metadata(self, sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) seq_len = max(seq_len, 1) + if is_prompt and max_pixels: + return self.create_dummy_multi_modal_seq_group_metadata( + group_id=group_id, + seq_len=seq_len, + max_pixels=max_pixels, + sampling_params=sampling_params, + lora_request=lora_request, + ) if is_prompt: input_len = seq_len output_len = 0 @@ -2074,25 +2077,24 @@ def create_dummy_seq_group_metadata(self, lora_request=lora_request) def profile_run(self) -> None: + # TODO FIX PROFILE + return num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, [kv_caches]) - # FIXME Going to set this to on big batch indepedent of bucketing_ctx - # _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - # max_batch_size = min(self.max_num_seqs, - # self.max_num_batched_tokens // max_seq_len) - max_batch_size = 1 - max_seq_len = self.max_num_batched_tokens + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_batch_size = min(self.max_num_seqs, + self.max_num_batched_tokens // max_seq_len) self.warmup_scenario( - batch_size=max_batch_size, + batch_size=1, seq_len=max_seq_len, is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, + max_pixels=_UNSET_MAX_PIXELS, is_lora_profile_run=True, - 
multimodal_seqs_group_metada=True, ) return @@ -2103,17 +2105,17 @@ def warmup_scenario(self, kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False, - multimodal_seqs_group_metada=False, temperature=0, max_pixels=None, return_time=False) -> None: - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, - multimodal_seqs_group_metada) + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, max_pixels) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") + f"multimodal{max_pixels if max_pixels else 'F'}_" + f"graphs{'T' if use_graphs else 'F'}" + ) # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -2142,18 +2144,7 @@ def warmup_scenario(self, if return_time: times += 1 - if multimodal_seqs_group_metada: - seqs = [ - self.create_dummy_multi_modal_seq_group_metadata( - group_id=i, - seq_len=seq_len, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None, - temperature=temperature, - max_pixels=max_pixels, - ) for i in range(batch_size) - ] - elif is_prompt: + if is_prompt: seqs = [ self.create_dummy_seq_group_metadata( i, @@ -2161,7 +2152,9 @@ def warmup_scenario(self, is_prompt, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None, - temperature=temperature) for i in range(batch_size) + max_pixels=max_pixels, + temperature=temperature + ) for i in range(batch_size) ] else: # FIXME: seq_len is actually number of blocks @@ -2284,7 +2277,7 @@ def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + def _warmup_multimodal(self, buckets, is_prompt, kv_caches): # TODO: The plan here is loop over a couple of image # resolutions and see if that helps during the warmup # somehow indepedent of the batch_size, seq_len @@ -2294,9 +2287,9 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): # Warmup Multimodal with fixed seq_len if not hasattr(self, 'visual_warmup_times'): self.visual_warmup_times = {} + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_batch_size = 1 # NOTE: Fix batch size to 1 for multimodal for i, max_pixels in enumerate(self.multimodal_buckets): - max_batch_size = 1 #TODO: For now we hardcoded batch 1. - max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX self.log_warmup_multimodal('Image', i, max_seq_len, max_batch_size, max_seq_len, max_pixels) t = self.warmup_scenario(batch_size=max_batch_size, @@ -2305,30 +2298,15 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): kv_caches=kv_caches, is_pt_profiler_run=False, is_lora_profile_run=True, - multimodal_seqs_group_metada=True, - max_pixels=max_pixels, + max_pixels=max_pixels * 14 * 14, return_time=True) - #if ((h*w) / (14*14)) in self.visual_warmup_times: - #breakpoint() - #print() - self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('nograph', t)] # TODO hardcoded "14" remove. "14" is a model specific number, maybe (h,w) or h*w is a better key? - - #Warmup without multimodal for text-prompt only - #TODO: We might need to warmup with smaller multimodal to generate - #3D position tensor for multimodal model. 
+ + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - #self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - if is_prompt: - self.warmup_scenario(batch_size, - seq_len, - is_prompt, - kv_caches, - multimodal_seqs_group_metada=True, - max_pixels=max_pixels) # everythign must be 112 aligned (for now) - else: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self._warmup_multimodal(buckets, is_prompt, kv_caches) def warmup_graphs(self, strategy, @@ -2374,8 +2352,7 @@ def warmup_graphs(self, kv_caches, temperature=1.0 if batch_size not in warmed_random_sampler_bs else 0, - multimodal_seqs_group_metada=True if is_prompt else False, - max_pixels=14*14*112) + ) warmed_random_sampler_bs.add(batch_size) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2385,29 +2362,27 @@ def warmup_graphs(self, # TODO: Multimodal HPU graph warmup need to be also check Memory, # and drop some buckets if memory is not sufficient. - print("WARMUP MULTIMODAL IMAGE GRAPH") - for idx, (h, w) in enumerate(self.multimodal_buckets): - graphed_multimodal_buckets = (1, h, w) - if graphed_multimodal_buckets in self.graphed_multimodal_buckets: - continue - self.graphed_multimodal_buckets.add(graphed_multimodal_buckets) - - for i, (b, h, w) in enumerate(self.graphed_multimodal_buckets): - max_batch_size = 1 #TODO: For now we hardcoded batch 1. - max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX (1680x1680 error on HPU GRAPH) + logger.info("[Multimodal] WARMUP IMAGE GRAPH") + for idx, max_pixels in enumerate(self.multimodal_buckets): + graphed_multimodal_buckets = (1, max_pixels) + if not graphed_multimodal_buckets in self.graphed_multimodal_buckets: + self.graphed_multimodal_buckets.add(graphed_multimodal_buckets) + + + for i, (b, max_pixels) in enumerate(self.graphed_multimodal_buckets): + max_batch_size = b #TODO: For now we hardcoded batch 1. 
+ assert max_batch_size == 1 # The visual warmup does not need to run in batches + #max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX (1680x1680 error on HPU GRAPH) + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() self.log_warmup_multimodal('Graph/Image', i, max_seq_len, - max_batch_size, max_seq_len, h, w) + max_batch_size, max_seq_len, max_pixels) t = self.warmup_scenario( batch_size=max_batch_size, seq_len=max_seq_len, is_prompt=True, kv_caches=kv_caches, - #is_pt_profiler_run=False, - #is_lora_profile_run=True, - multimodal_seqs_group_metada=True, - max_pixels=max_pixels, + max_pixels=max_pixels * 14 * 14, return_time=True) - self.visual_warmup_times[((h*w) / (14*14))] = self.visual_warmup_times.get(((h*w) / (14*14)), []) + [('graph', t)] return total_mem, total_batch_seq, captured_all From 2042673bb4148e1992d3ba0bfe746e240c5808ef Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 16 Apr 2025 04:00:54 +0000 Subject: [PATCH 12/38] Video inputs ignored for now --- vllm/model_executor/models/qwen2_5_vl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 95253fde9ddd..b6ce54bf21b9 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1217,6 +1217,9 @@ def get_input_embeddings_v0( ) if video_input is not None: + if is_hpu: + print("Video inputs have not been enabled/verified yet, ignoring video inputs") + return inputs_embeds video_embeds = self._process_video_input(video_input) inputs_embeds = merge_multimodal_embeddings( input_ids, From cfb7809a549cbc999cd4c60698c123aad181f065 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 16 Apr 2025 04:02:03 +0000 Subject: [PATCH 13/38] Remove unused return_time --- vllm/worker/hpu_model_runner.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 6439481f1d6d..de1fc5c2ccf6 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -463,7 +463,6 @@ def forward(self, *args, **kwargs): with compile_only_mode_context(): #calculate embedding for multimodal - #breakpoint() image_input = self.model._parse_and_validate_image_input( **kwargs) video_input = self.model._parse_and_validate_video_input( @@ -2106,8 +2105,7 @@ def warmup_scenario(self, is_pt_profiler_run=False, is_lora_profile_run=False, temperature=0, - max_pixels=None, - return_time=False) -> None: + max_pixels=None) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, max_pixels) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" @@ -2141,8 +2139,6 @@ def warmup_scenario(self, ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_pt_profiler_run else 1 - if return_time: - times += 1 if is_prompt: seqs = [ @@ -2176,8 +2172,6 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) - if return_time: - tstart = time.time() is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: @@ -2211,15 +2205,13 @@ def warmup_scenario(self, num_steps=2, seqs=seqs) torch.hpu.synchronize() - if return_time: - t_total = time.time() - tstart + if profiler: profiler.step() if profiler: profiler.stop() self.profiler.end() gc.collect() - return t_total if return_time else None def remove_all_loras(self): if not self.lora_manager: @@ 
-2292,14 +2284,13 @@ def _warmup_multimodal(self, buckets, is_prompt, kv_caches): for i, max_pixels in enumerate(self.multimodal_buckets): self.log_warmup_multimodal('Image', i, max_seq_len, max_batch_size, max_seq_len, max_pixels) - t = self.warmup_scenario(batch_size=max_batch_size, + self.warmup_scenario(batch_size=max_batch_size, seq_len=max_seq_len, is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, is_lora_profile_run=True, - max_pixels=max_pixels * 14 * 14, - return_time=True) + max_pixels=max_pixels * 14 * 14) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): @@ -2376,13 +2367,12 @@ def warmup_graphs(self, _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() self.log_warmup_multimodal('Graph/Image', i, max_seq_len, max_batch_size, max_seq_len, max_pixels) - t = self.warmup_scenario( + self.warmup_scenario( batch_size=max_batch_size, seq_len=max_seq_len, is_prompt=True, kv_caches=kv_caches, - max_pixels=max_pixels * 14 * 14, - return_time=True) + max_pixels=max_pixels * 14 * 14) return total_mem, total_batch_seq, captured_all From 3cce3bb4fd19f755dad0dc2174d0639fe8419ea8 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 16 Apr 2025 04:20:52 +0000 Subject: [PATCH 14/38] Add warning about 112 alignment --- vllm/model_executor/models/qwen2_5_vl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b6ce54bf21b9..a2834c7054e7 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -722,7 +722,8 @@ def forward( rotary_pos_emb: torch.Tensor) -> torch.Tensor: assert x.shape[0] == cu_seqlens[-1] == rotary_pos_emb.shape[0] cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) - + if is_hpu: + assert x.shape[0]%64 == 0, "Expect inputs to be 112x112 aligned. Please align before sending image or use this version of transformer that does the resizing/alignment automatically: pip install git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" hidden_states = x.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: From 1c4d44c688c245c5adb6a329ce0dda2f00031e9c Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 16 Apr 2025 05:21:34 +0000 Subject: [PATCH 15/38] Move VisionBuckets out to hpu model runner --- vllm/model_executor/models/qwen2_5_vl.py | 14 +------- vllm/worker/hpu_model_runner.py | 42 +++++++++++++++++++++--- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a2834c7054e7..bd5ac4cbea5d 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -901,12 +901,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - envvar = os.environ.get('FIXED_MULTIMODAL_BUCKETS', "") - if envvar == "": - self.FIXED_MULTIMODAL_BUCKETS = [1600, 3200, 4800, 6400] # add 768 a small bucket maybe? 
- else: - self.FIXED_MULTIMODAL_BUCKETS = [int(i) for i in envvar.split(',')] - assert all([k%64 == 0 for k in self.FIXED_MULTIMODAL_BUCKETS]), f"FIXED_MULTIMODAL_BUCKETS should all be multiples of 64, but was {self.FIXED_MULTIMODAL_BUCKETS}" @cached_property def sampler(self): @@ -1012,18 +1006,12 @@ def _parse_and_validate_video_input( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) - def _get_multimodal_bucket(self, curr_num_image_patches): - for mm_bucket in self.FIXED_MULTIMODAL_BUCKETS: - if curr_num_image_patches <= mm_bucket: - return mm_bucket - self.FIXED_MULTIMODAL_BUCKETS += [curr_num_image_patches] # a shape larger than any that was compiled before. its gonna be compiled now, so save it for the future - return curr_num_image_patches def pad_multimodal_data(self, pixel_values, image_grid_thw): assert pixel_values.shape[ 0] % 64 == 0, '[testing version] needs 64 aligned resolution' - desired_number_of_pixels = self._get_multimodal_bucket(pixel_values.shape[0]) + desired_number_of_pixels = self.vision_buckets.get_multimodal_bucket(pixel_values.shape[0]) padding_len = desired_number_of_pixels - pixel_values.shape[0] if padding_len <= 0: #breakpoint() diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index de1fc5c2ccf6..f7291d591cae 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -88,6 +88,28 @@ DUMMY_TOKEN_ID = -1 +''' +This class is used to bucket image tokens +''' +class VisionBuckets(): + def __init__(self): + envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") + if envvar == "": + self.multimodal_buckets = [1600, 3200, 4800, 6400] + else: + self.multimodal_buckets = [int(i) for i in envvar.split(',')] + + def get_multimodal_bucket(self, curr_num_image_patches): + for mm_bucket in self.multimodal_buckets: + if curr_num_image_patches <= mm_bucket: + return mm_bucket + self.multimodal_buckets += [curr_num_image_patches] # a shape larger than any that was compiled before. 
its gonna be compiled now, so save it for the future + return curr_num_image_patches + + def __repr__(self): + return str(self.multimodal_buckets) + + class PhaseType(Enum): PREFILL = 'prefill' PREFIX_PREFILL = 'prefix_prefill' @@ -762,7 +784,7 @@ def __init__( self.use_merged_prefill, self.max_model_len) self.graphed_buckets: Set[Any] = set() - self.multimodal_buckets = [] #This should be use HPUBucketingContext + self.multimodal_buckets = [] #This should be use HPUBucketingContext << self.graphed_multimodal_buckets: Set[Any] = set() self._set_gc_threshold() @@ -923,6 +945,8 @@ def load_model(self) -> None: self.model_memory_usage = m.consumed_device_memory msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) + self.add_vision_buckets_to_model() + def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): real_batch_size = len(seq_group_metadata_list) @@ -1120,6 +1144,16 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype): def move_to_device(self, tensor): return tensor if tensor is None else tensor.to(self.device, non_blocking=True) + ''' + Right now Qwen2.5VL needs to know these buckets so it can do some things internally + ''' + def add_vision_buckets_to_model(self): + if supports_multimodal(self.get_model()): + vb = VisionBuckets() + if isinstance(self.model, HpuModelAdapter): + self.model.model.vision_buckets = vb + else: + self.model.vision_buckets = vb def _prepare_prompt( self, @@ -2011,7 +2045,7 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, if max_pixels and max_pixels != _UNSET_MAX_PIXELS: hf_processor_mm_kwargs = dict(processor_inputs.hf_processor_mm_kwargs) hf_processor_mm_kwargs["max_pixels"] = max_pixels - print(f" ===== {max_pixels} ====== : ", hf_processor_mm_kwargs) + print(f" ===== {max_pixels} ====== : ", hf_processor_mm_kwargs) # logger.info mm_inputs = processor.apply( prompt=processor_inputs.prompt_text, mm_data=processor_inputs.mm_data, @@ -2396,8 +2430,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.bucketing_ctx.generate_decode_buckets(max_blocks) if supports_multimodal(self.model.model): - self.multimodal_buckets = self.model.model.FIXED_MULTIMODAL_BUCKETS - print("Multimodal bucket :", self.multimodal_buckets) + self.multimodal_buckets = self.get_model().vision_buckets.multimodal_buckets + logger.info(f"Multimodal bucket : {self.multimodal_buckets}") if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') From 5bbfdebe3ee8dc8eeec1ef84f0c02c31520ba4eb Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Wed, 16 Apr 2025 21:50:29 +0000 Subject: [PATCH 16/38] Create full attention mask outside of VisionTransformer full_attention_mask doesn't need to be created for each full attention layer, only create once and reuse. This can save memory and time. 
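
Concretely, the block-diagonal mask depends only on cu_seqlens, so it can be computed once up front and passed to every full-attention block instead of being rebuilt inside each layer. A minimal, self-contained PyTorch sketch of such a mask (illustrative only; the series itself uses create_block_diagonal_attention_mask_outerprod and threads the result through as fullattn_mask):

    import torch

    def block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
        # cu_seqlens = [0, n1, n1+n2, ...]; True where query and key belong to the same image
        total = int(cu_seqlens[-1])
        seg = torch.zeros(total, dtype=torch.long)
        seg[cu_seqlens[1:-1]] = 1          # mark the first token of every segment after the first
        seg = seg.cumsum(0)                # segment id per token
        return seg.unsqueeze(0) == seg.unsqueeze(1)

    cu_seqlens = torch.tensor([0, 64, 192])      # two images: 64 and 128 patches
    mask = block_diagonal_mask(cu_seqlens)       # built once ...
    # ... and reused by every block listed in fullatt_block_indexes

This matches the hunks below, where the mask is created in _process_image_input and handed to Qwen2_5_VisionTransformer.forward rather than being recreated per layer.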
--- vllm/model_executor/models/qwen2_5_vl.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index bd5ac4cbea5d..70bf8899e932 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -305,6 +305,7 @@ def forward( self, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor], + fullattn_mask: Optional[torch.Tensor], rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] @@ -370,7 +371,7 @@ def forward( outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) else: - fullatt_block_attn_mask = create_block_diagonal_attention_mask_outerprod(cu_seqlens) + fullatt_block_attn_mask = fullattn_mask q1, k1, v1 = (rearrange(x, "b s h d -> b h s d")for x in [q, k, v]) @@ -431,11 +432,13 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp") - def forward(self, x: torch.Tensor, #cu_seqlens: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], + def forward(self, x: torch.Tensor, + cu_seqlens: torch.Tensor, + fullattn_mask: Optional[torch.Tensor], rotary_pos_emb: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, + fullattn_mask=fullattn_mask, rotary_pos_emb=rotary_pos_emb) x = x + self.mlp(self.norm2(x)) return x @@ -713,25 +716,28 @@ def remove_duplicates_cpu(a): cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) return hidden_states, rotary_pos_emb, cu_seqlens, cu_window_seqlens, window_index def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, + fullattn_mask: Optional[torch.Tensor], rotary_pos_emb: torch.Tensor) -> torch.Tensor: - assert x.shape[0] == cu_seqlens[-1] == rotary_pos_emb.shape[0] - cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) if is_hpu: assert x.shape[0]%64 == 0, "Expect inputs to be 112x112 aligned. Please align before sending image or use this version of transformer that does the resizing/alignment automatically: pip install git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" hidden_states = x.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): + #TODO: now we premake fullattn_mask, we don't need to pass cu_seqlens + #but keep it here for now since other ATTN is using this argument. Need to clean code. if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens else: cu_seqlens_now = None hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, + fullattn_mask=fullattn_mask, rotary_pos_emb=rotary_pos_emb) # adapter @@ -1086,10 +1092,15 @@ def _process_image_input( assert pixel_values.shape[0] % 64 == 0, f"We need image h/w to be aligned to 112 for now. Which will make pixel_values be a multiple of (112/14)*(112/14)=64 (14 is patch size for ViT). 
Got pixel_values shape {pixel_values.shape[0]}" expanded_cu_seqlens = expand_to_max(cu_seqlens, 3) # either a single image, or a single image and its accompanying pad image, so only max expansion to 3 + #Create full attention block mast before VisionTransformer to save memory/time + #TODO cu_seqlens can be removed but keep it here for now + fullatt_block_attn_mask = create_block_diagonal_attention_mask_outerprod(cu_seqlens) + assert pixel_values_curr_img_padded.shape[0] == expanded_cu_seqlens[-1] == rot_pos_emb.shape[0] htcore.mark_step() hidden_states = self.visual(pixel_values_curr_img_padded, rotary_pos_emb=rot_pos_emb, - cu_seqlens=expanded_cu_seqlens,) + cu_seqlens=expanded_cu_seqlens, + fullattn_mask=fullatt_block_attn_mask,) htcore.mark_step() image_embeds = self.visual.post_attn(hidden_states, window_index) results += [image_embeds[:img_shape_padded[0].prod()//4, :]] # slice image_embeds to remove the padded parts. instead of hardcoding 4, maybe use config spatial merge etc From 151b3e3ef4fad74c904cf77a349331ff9e217c0e Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 16 Apr 2025 23:08:22 +0000 Subject: [PATCH 17/38] warmup multimoda graph with memory track? --- vllm/worker/hpu_model_runner.py | 145 ++++++++++++++++++++------------ 1 file changed, 89 insertions(+), 56 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f7291d591cae..fdb485d2a9d8 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1148,12 +1148,9 @@ def move_to_device(self, tensor): Right now Qwen2.5VL needs to know these buckets so it can do some things internally ''' def add_vision_buckets_to_model(self): - if supports_multimodal(self.get_model()): - vb = VisionBuckets() - if isinstance(self.model, HpuModelAdapter): - self.model.model.vision_buckets = vb - else: - self.model.vision_buckets = vb + model = self.get_model() + if supports_multimodal(model): + model.vision_buckets = VisionBuckets() def _prepare_prompt( self, @@ -2028,7 +2025,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, max_pixels, sampling_params, lora_request): - assert self.mm_registry.has_processor(self.model_config) + assert self.mm_registry.has_processor(self.model_config), 'Multimodal Warmup needs a processor' tokenizer = cached_get_tokenizer( self.model_config.tokenizer, trust_remote_code=self.model_config.trust_remote_code, @@ -2042,10 +2039,11 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, seq_len=seq_len, mm_counts=mm_counts, ) + + hf_processor_mm_kwargs = dict(processor_inputs.hf_processor_mm_kwargs) if max_pixels and max_pixels != _UNSET_MAX_PIXELS: - hf_processor_mm_kwargs = dict(processor_inputs.hf_processor_mm_kwargs) hf_processor_mm_kwargs["max_pixels"] = max_pixels - print(f" ===== {max_pixels} ====== : ", hf_processor_mm_kwargs) # logger.info + mm_inputs = processor.apply( prompt=processor_inputs.prompt_text, mm_data=processor_inputs.mm_data, @@ -2054,7 +2052,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] - prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + num_tokens_to_extend = seq_len - len(prompt_token_ids) + assert num_tokens_to_extend > 0, "seq_len is smaller than multimodal tokens" + prompt_token_ids.extend([0] * (num_tokens_to_extend)) seq_data = 
SequenceData.from_seqs(prompt_token_ids) return SequenceGroupMetadata( @@ -2294,32 +2294,28 @@ def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, max_pixels): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) - dim = "num_blocks" - if "Prompt" in phase: - dim = "seq_len" + dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " - f"{dim}:{seq_len}", f"max_pixels:{max_pixels}", + f"{dim}:{seq_len} " + f"max_pixels:{max_pixels} " f"free_mem:{free_mem}") logger.info(msg) - def _warmup_multimodal(self, buckets, is_prompt, kv_caches): - # TODO: The plan here is loop over a couple of image - # resolutions and see if that helps during the warmup - # somehow indepedent of the batch_size, seq_len - # might need to mark.step() somewhere to split the - # HPU graph for video and language model - - # Warmup Multimodal with fixed seq_len - if not hasattr(self, 'visual_warmup_times'): - self.visual_warmup_times = {} + def _warmup_multimodal(self, kv_caches): + if not supports_multimodal(self.get_model()): + return _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - max_batch_size = 1 # NOTE: Fix batch size to 1 for multimodal + seq_len = max_seq_len + batch_size = 1 + phase = 'Multimodal' + num_candidates = len(self.multimodal_buckets) for i, max_pixels in enumerate(self.multimodal_buckets): - self.log_warmup_multimodal('Image', i, max_seq_len, max_batch_size, - max_seq_len, max_pixels) - self.warmup_scenario(batch_size=max_batch_size, - seq_len=max_seq_len, + self.log_warmup_multimodal(phase, i, num_candidates, + batch_size, seq_len, + max_pixels) + self.warmup_scenario(batch_size=batch_size, + seq_len=seq_len, is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, @@ -2331,7 +2327,7 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - self._warmup_multimodal(buckets, is_prompt, kv_caches) + self._warmup_multimodal(kv_caches) def warmup_graphs(self, strategy, @@ -2385,28 +2381,60 @@ def warmup_graphs(self, total_mem += used_mem total_batch_seq += batch_seq - # TODO: Multimodal HPU graph warmup need to be also check Memory, - # and drop some buckets if memory is not sufficient. 
- logger.info("[Multimodal] WARMUP IMAGE GRAPH") + mm_outputs = \ + self._warmup_multimodal_graph( + kv_caches=kv_caches, + available_mem=available_mem, + starting_mem=total_mem, + total_batch_seq=total_batch_seq, + ) + if mm_outputs is not None: + total_mem, total_batch_seq, mm_captured_all = mm_outputs + captured_all = captured_all and mm_captured_all + return total_mem, total_batch_seq, captured_all + + def _warmup_multimodal_graph(self, + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001): + if not supports_multimodal(self.get_model()): + return None + total_mem = starting_mem + idx = 0 + phase = f'Graph/Multimodal' + num_candidates = len(self.multimodal_buckets) + captured_all = True for idx, max_pixels in enumerate(self.multimodal_buckets): - graphed_multimodal_buckets = (1, max_pixels) - if not graphed_multimodal_buckets in self.graphed_multimodal_buckets: - self.graphed_multimodal_buckets.add(graphed_multimodal_buckets) + batch_size = 1 # Note: Multimodal buckets are indepedent of batch_size + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + seq_len = max_seq_len + batch_seq = 1 * max_pixels + # Graph memory usage is proportional to seq dimension in a batch + mem_estimate = batch_seq / total_batch_seq * total_mem + if mem_estimate >= available_mem: + captured_all = False + continue + graphed_multimodal_bucket = max_pixels + if graphed_multimodal_bucket in self.graphed_multimodal_buckets: + continue + self.graphed_multimodal_buckets.add(graphed_multimodal_bucket) + self.log_warmup_multimodal(phase, idx, num_candidates, + batch_size, seq_len, max_pixels) + with HabanaMemoryProfiler() as mem_prof: + self.warmup_scenario( + batch_size=batch_size, + seq_len=seq_len, + is_prompt=True, + kv_caches=kv_caches, + max_pixels=max_pixels * 14 * 14) - for i, (b, max_pixels) in enumerate(self.graphed_multimodal_buckets): - max_batch_size = b #TODO: For now we hardcoded batch 1. 
- assert max_batch_size == 1 # The visual warmup does not need to run in batches - #max_seq_len = 2048 #TODO: set with VLLM_PROMPT_SEQ_BUCKET_MAX (1680x1680 error on HPU GRAPH) - _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - self.log_warmup_multimodal('Graph/Image', i, max_seq_len, - max_batch_size, max_seq_len, max_pixels) - self.warmup_scenario( - batch_size=max_batch_size, - seq_len=max_seq_len, - is_prompt=True, - kv_caches=kv_caches, - max_pixels=max_pixels * 14 * 14) + used_mem = align_workers(mem_prof.consumed_device_memory, + torch.distributed.ReduceOp.MAX) + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq return total_mem, total_batch_seq, captured_all @@ -2421,7 +2449,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' f'buckets:{sorted(list(graphed))}') - logger.info(msg) + logger.info(msg) + if "Prompt" in phase and len(self.multimodal_buckets) > 0: + phase = "Graph/Multimodal" + num_candidates = len(self.multimodal_buckets) + mm_graphed = self.graphed_multimodal_buckets + msg = (f'{phase} captured:{len(mm_graphed)} ' + f'({100 * len(mm_graphed) / num_candidates:.1f}%) ' + f'buckets:{sorted(list(mm_graphed))}') + logger.info(msg) @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: @@ -2429,8 +2465,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: max_blocks = kv_caches[0][0].size(0) self.bucketing_ctx.generate_decode_buckets(max_blocks) - if supports_multimodal(self.model.model): - self.multimodal_buckets = self.get_model().vision_buckets.multimodal_buckets + model = self.get_model() + if supports_multimodal(model): + self.multimodal_buckets = model.vision_buckets.multimodal_buckets logger.info(f"Multimodal bucket : {self.multimodal_buckets}") if profile := os.environ.get('VLLM_PT_PROFILE', None): @@ -2588,10 +2625,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: f"Warmup finished in {elapsed_time:.0f} secs, " f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) - # if hasattr(self, 'visual_warmup_times'): - # summary = {k: min([t for _, t in self.visual_warmup_times[k]]) for k in self.visual_warmup_times} - # self.visual_warmup_times = summary - # self.model.visual_warmup_times = self.visual_warmup_times self.profiler.end() def finish_measurements(self): From cd4a9de2549117c67dd46cb4730996e2ad53061e Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Thu, 17 Apr 2025 02:26:30 +0000 Subject: [PATCH 18/38] Enable profile_run and set disable_tensor_cache=True profile_run takes maximum tensor size of 65K. To support it, we need to reduce significant memory usage by adding below. 
- Set disable_tensor_cache=True for vision model as well - Add additional mark_step to split the graphs - Move einsum operation to cpu for bigger tensor(due to GC error) - Run FusedSDPA for longer sequence as well --- vllm/model_executor/models/qwen2_5_vl.py | 20 ++++++++++++++++++-- vllm/worker/hpu_model_runner.py | 9 ++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 70bf8899e932..9997b0cf56cf 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -112,7 +112,16 @@ def create_block_diagonal_attention_mask_outerprod(indices): zz = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) xx = torch.logical_and(yy, zz) # can reduce sum externally or as batchmatmul - res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) + #TODO: einsum with tensor dimension too big doesn't work. Register max size error. + #We can always move to CPU for all einsum without shape checking if perf impact is minimal. + if xx.shape[-1] > 40000: + print("einsum running on CPU : ", xx.shape) + xx = xx.to("cpu") + res = torch.einsum('bi,bj->bij', xx, xx) + res = res.to("hpu") + res = torch.sum(res, dim=0) + else: + res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) #res = torch.einsum('bi,bj->ij', xx.float(), xx.float()) return res.bool() @@ -353,6 +362,10 @@ def forward( outputs = [] cu_seqlens = list(range(0, x.shape[0]+1, 64)) # assuming x%64=0 (image is 112 aligned in both h/w dims) for i in range(1, len(cu_seqlens)): + #TODO: Check if number 100 is good + #For large image, we add mark step here for every 100th step to make compile time shorter + if i % 100 == 0: + htcore.mark_step() start_idx = cu_seqlens[i - 1] end_idx = cu_seqlens[i] q_i = q[:, start_idx:end_idx] @@ -381,7 +394,9 @@ def forward( attn_mask = fullatt_block_attn_mask.reshape(batch_size, 1, seq_len_N_t, seq_len_N_s, -1)[:, :, :, :, 0] assert attn_mask.shape == mask_shape - if q1.shape[2] <= 6400: # this crossover point should be measured + #TODO:after 1by1 branch, even with long sequence, FusedSDPA is much faster + # Setting the number here to the max number we get in profile_run. + if q1.shape[2] <= 65536: # this crossover point should be measured fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0) # Bx1xNxN else: fused_out = AttentionLongSequence.forward(q1, k1, v1, attn_mask, 64) @@ -729,6 +744,7 @@ def forward( assert x.shape[0]%64 == 0, "Expect inputs to be 112x112 aligned. Please align before sending image or use this version of transformer that does the resizing/alignment automatically: pip install git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" hidden_states = x.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): + htcore.mark_step() #TODO: now we premake fullattn_mask, we don't need to pass cu_seqlens #but keep it here for now since other ATTN is using this argument. Need to clean code. 
if layer_num in self.fullatt_block_indexes: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index fdb485d2a9d8..3a00d2fb6721 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -95,7 +95,9 @@ class VisionBuckets(): def __init__(self): envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") if envvar == "": - self.multimodal_buckets = [1600, 3200, 4800, 6400] + #TODO:with profile_run, the bucket of 65536 is added, so the pixel values + #bigger than 12800 below always padded to 65536 which is too big. + self.multimodal_buckets = [1600, 3200, 4800, 6400, 9600, 12800] else: self.multimodal_buckets = [int(i) for i in envvar.split(',')] @@ -299,7 +301,7 @@ def __init__(self, model, vllm_config, layer_names): if htorch.utils.internal.is_lazy() and self.split_graph: logger.info("[Multimodal] Split Graph to Visual and Language") self.model.visual = htorch.hpu.wrap_in_hpu_graph( - self.model.visual, disable_tensor_cache=False) + self.model.visual, disable_tensor_cache=True) self.model.language_model.model = htorch.hpu.wrap_in_hpu_graph( self.model.language_model.model, disable_tensor_cache=True) @@ -2053,7 +2055,6 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] num_tokens_to_extend = seq_len - len(prompt_token_ids) - assert num_tokens_to_extend > 0, "seq_len is smaller than multimodal tokens" prompt_token_ids.extend([0] * (num_tokens_to_extend)) seq_data = SequenceData.from_seqs(prompt_token_ids) @@ -2110,8 +2111,6 @@ def create_dummy_seq_group_metadata(self, lora_request=lora_request) def profile_run(self) -> None: - # TODO FIX PROFILE - return num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers bind_kv_cache( From 260e7247bb235102087443e52aa5c830af76b4b9 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Thu, 17 Apr 2025 21:40:56 +0000 Subject: [PATCH 19/38] we dont need this anymore --- vllm/model_executor/models/utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index e55812d1a2a2..e9598077a379 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -366,14 +366,13 @@ def _merge_multimodal_embeddings( assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) - # if flattened.shape[0] != num_expected_tokens: - # expr = _embedding_count_expression(multimodal_embeddings) - # raise ValueError( - # f"Attempted to assign {expr} = {flattened.shape[0]} " - # f"multimodal tokens to {num_expected_tokens} placeholders") - - # flattened could have dummy data from the padding after num_expected_tokens - inputs_embeds[is_multimodal] = flattened[:num_expected_tokens, :] + if flattened.shape[0] != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {flattened.shape[0]} " + f"multimodal tokens to {num_expected_tokens} placeholders") + + inputs_embeds[is_multimodal] = flattened return inputs_embeds From 769cf6f7cf065082a6e243b391b7c993e069753e Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Fri, 18 Apr 2025 02:30:15 +0000 Subject: [PATCH 20/38] we dont need b dim --- vllm/model_executor/models/qwen2_5_vl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git 
a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 9997b0cf56cf..edca70c906f3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -110,18 +110,21 @@ def create_block_diagonal_attention_mask_outerprod(indices): range_to_max_for_each_img = torch.arange(maxsize, device=indices.device).unsqueeze(0).repeat(indices.shape[0]-1,1) yy = range_to_max_for_each_img < indices[1:].unsqueeze(1) zz = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) - xx = torch.logical_and(yy, zz) + xx = torch.logical_and(yy, zz).float() # can reduce sum externally or as batchmatmul #TODO: einsum with tensor dimension too big doesn't work. Register max size error. #We can always move to CPU for all einsum without shape checking if perf impact is minimal. if xx.shape[-1] > 40000: print("einsum running on CPU : ", xx.shape) xx = xx.to("cpu") - res = torch.einsum('bi,bj->bij', xx, xx) + res = torch.einsum('bi,bj->ij', xx, xx) + #breakpoint() res = res.to("hpu") - res = torch.sum(res, dim=0) + #res = torch.sum(res, dim=0) else: - res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) + #res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) + res = torch.einsum('bi,bj->ij', xx, xx) + #print('.....MASK SHAPE', res.shape, indices) #res = torch.einsum('bi,bj->ij', xx.float(), xx.float()) return res.bool() @@ -396,6 +399,7 @@ def forward( #TODO:after 1by1 branch, even with long sequence, FusedSDPA is much faster # Setting the number here to the max number we get in profile_run. + #print('BATCHSIZE:: ', batch_size, '.........................................') if q1.shape[2] <= 65536: # this crossover point should be measured fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0) # Bx1xNxN else: From 79f65e06ba687bcdd3c8d98176084696ff526088 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Fri, 18 Apr 2025 18:58:01 +0000 Subject: [PATCH 21/38] Fix use_graph to return correctly for multimodal buckets - fix use_graph to detect multimodal bucket correctly - pass the right pixel size for execution - change multimodal buckets to align with resize - remove multimodal warmup for Decode --- vllm/model_executor/models/qwen2_5_vl.py | 1 - vllm/worker/hpu_model_runner.py | 30 +++++++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index edca70c906f3..8cf898022d82 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1225,7 +1225,6 @@ def get_input_embeddings_v0( image_input: Optional[tuple[torch.Tensor, ...]] = None, video_input: Optional[tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) if image_input is not None: image_embeds = self._process_image_input(image_input) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 3a00d2fb6721..3c09c0a1ab4a 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -97,7 +97,8 @@ def __init__(self): if envvar == "": #TODO:with profile_run, the bucket of 65536 is added, so the pixel values #bigger than 12800 below always padded to 65536 which is too big. 
- self.multimodal_buckets = [1600, 3200, 4800, 6400, 9600, 12800] + #self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544, 16384, 26500, 40000, 65536] + self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] else: self.multimodal_buckets = [int(i) for i in envvar.split(',')] @@ -105,7 +106,8 @@ def get_multimodal_bucket(self, curr_num_image_patches): for mm_bucket in self.multimodal_buckets: if curr_num_image_patches <= mm_bucket: return mm_bucket - self.multimodal_buckets += [curr_num_image_patches] # a shape larger than any that was compiled before. its gonna be compiled now, so save it for the future + #Remove dynamic bucket expands since this is not done for the language model. + #self.multimodal_buckets += [curr_num_image_patches] # a shape larger than any that was compiled before. its gonna be compiled now, so save it for the future return curr_num_image_patches def __repr__(self): @@ -1051,10 +1053,12 @@ def _use_graphs(self, batch_size, seq_len, is_prompt, max_pixels=None): return False if self.skip_warmup: return True - if not max_pixels or not self.graphed_multimodal_buckets: + if not max_pixels: return (batch_size, seq_len, is_prompt) in self.graphed_buckets else: - return (batch_size, seq_len, is_prompt, max_pixels) in self.graphed_buckets + #TODO: We might need to check both language bucket and multimodal bucket + # and return True only it's avialble, or return seperately. + return (max_pixels) in self.graphed_multimodal_buckets def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens @@ -1430,6 +1434,7 @@ def _prepare_prompt( enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + for t in multi_modal_kwargs: if torch.is_tensor(multi_modal_kwargs[t]): multi_modal_kwargs[t] = multi_modal_kwargs[t].to( @@ -2119,8 +2124,9 @@ def profile_run(self) -> None: _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) + max_batch_size = max_batch_size if not self.model_is_mrope else 1 self.warmup_scenario( - batch_size=1, + batch_size=max_batch_size, seq_len=max_seq_len, is_prompt=True, kv_caches=kv_caches, @@ -2309,6 +2315,7 @@ def _warmup_multimodal(self, kv_caches): batch_size = 1 phase = 'Multimodal' num_candidates = len(self.multimodal_buckets) + for i, max_pixels in enumerate(self.multimodal_buckets): self.log_warmup_multimodal(phase, i, num_candidates, batch_size, seq_len, @@ -2326,7 +2333,8 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - self._warmup_multimodal(kv_caches) + if is_prompt: + self._warmup_multimodal(kv_caches) def warmup_graphs(self, strategy, @@ -2893,8 +2901,14 @@ def execute_model( assert is_prompt is not None batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, - self.model_is_mrope) + if self.model_is_mrope and is_prompt and hasattr(model_input, 'multi_modal_kwargs') and \ + model_input.multi_modal_kwargs is not None and 'pixel_values' in model_input.multi_modal_kwargs: + max_pixel = model_input.multi_modal_kwargs['pixel_values'].shape[-2] + #print(f"max_pixel: {max_pixel}, count:{(model_input.input_tokens == 151655).sum().item()}, grid_thw: {model_input.multi_modal_kwargs['image_grid_thw']}") + else: 
+ max_pixel = None + + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, max_pixels=max_pixel) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None From cf4920377a8ec25dd7b39c0e2f123dc7498a8842 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Fri, 18 Apr 2025 23:56:27 +0000 Subject: [PATCH 22/38] sort vision buckets check if max bucket size of incoming images is graphed --- vllm/worker/hpu_model_runner.py | 57 +++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 3c09c0a1ab4a..0a35c9db7b8d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -98,9 +98,10 @@ def __init__(self): #TODO:with profile_run, the bucket of 65536 is added, so the pixel values #bigger than 12800 below always padded to 65536 which is too big. #self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544, 16384, 26500, 40000, 65536] - self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] + multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] else: - self.multimodal_buckets = [int(i) for i in envvar.split(',')] + multimodal_buckets = [int(i) for i in envvar.split(',')] + self.multimodal_buckets = sorted(multimodal_buckets) def get_multimodal_bucket(self, curr_num_image_patches): for mm_bucket in self.multimodal_buckets: @@ -1055,10 +1056,9 @@ def _use_graphs(self, batch_size, seq_len, is_prompt, max_pixels=None): return True if not max_pixels: return (batch_size, seq_len, is_prompt) in self.graphed_buckets - else: - #TODO: We might need to check both language bucket and multimodal bucket - # and return True only it's avialble, or return seperately. - return (max_pixels) in self.graphed_multimodal_buckets + #TODO: We might need to check both language bucket and multimodal bucket + # and return True only it's avialble, or return seperately. 
+ return (max_pixels) in self.graphed_multimodal_buckets def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens @@ -2124,6 +2124,7 @@ def profile_run(self) -> None: _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) + # Using batch_size 1 is profile multimodal models max_batch_size = max_batch_size if not self.model_is_mrope else 1 self.warmup_scenario( batch_size=max_batch_size, @@ -2145,7 +2146,16 @@ def warmup_scenario(self, is_lora_profile_run=False, temperature=0, max_pixels=None) -> None: - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, max_pixels) + if max_pixels: + # TODO: Find a better way to convert from raw pixel values to + # pixel_values from patches + max_pixels = max_pixels * 14**2 + use_graphs = self._use_graphs( + batch_size, + seq_len, + is_prompt, + max_pixels, + ) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" @@ -2326,7 +2336,7 @@ def _warmup_multimodal(self, kv_caches): kv_caches=kv_caches, is_pt_profiler_run=False, is_lora_profile_run=True, - max_pixels=max_pixels * 14 * 14) + max_pixels=max_pixels) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): @@ -2435,7 +2445,7 @@ def _warmup_multimodal_graph(self, seq_len=seq_len, is_prompt=True, kv_caches=kv_caches, - max_pixels=max_pixels * 14 * 14) + max_pixels=max_pixels) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2804,6 +2814,22 @@ def _get_seq_ids(self, model_input): sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups ]) + def _get_max_pixels_from_model_input(self, model_input): + if not model_input.multi_modal_kwargs or 'pixel_values' not in model_input.multi_modal_kwargs: + return None + pixel_values_list = model_input.multi_modal_kwargs['pixel_values'] + if isinstance(pixel_values_list, torch.Tensor): + pixel_values_list = [pixel_values_list] + assert isinstance(pixel_values_list, list) + model = self.get_model() + max_bucket_size = 0 + for pixel_values in pixel_values_list: + assert isinstance(pixel_values, torch.Tensor) + curr_num_pixels = pixel_values.shape[-2] + bucket_size = model.vision_buckets.get_multimodal_bucket(curr_num_pixels) + max_bucket_size = max(max_bucket_size, bucket_size) + return max_bucket_size + def _pad_to_max_num_seqs(self, tensor, value): padding_needed = self.max_num_seqs - tensor.size(0) if padding_needed: @@ -2901,14 +2927,11 @@ def execute_model( assert is_prompt is not None batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) - if self.model_is_mrope and is_prompt and hasattr(model_input, 'multi_modal_kwargs') and \ - model_input.multi_modal_kwargs is not None and 'pixel_values' in model_input.multi_modal_kwargs: - max_pixel = model_input.multi_modal_kwargs['pixel_values'].shape[-2] - #print(f"max_pixel: {max_pixel}, count:{(model_input.input_tokens == 151655).sum().item()}, grid_thw: {model_input.multi_modal_kwargs['image_grid_thw']}") - else: - max_pixel = None - - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt, max_pixels=max_pixel) + max_pixels = self._get_max_pixels_from_model_input(model_input) + use_graphs = self._use_graphs(batch_size=batch_size, + seq_len=seq_len, + is_prompt=is_prompt, + max_pixels=max_pixels) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None From 
bbe1571b4949b8be32981b69c45ba64cbda1ac6e Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Sat, 19 Apr 2025 00:11:24 +0000 Subject: [PATCH 23/38] set input_positions in text to be (3, seq_len) for mrope models --- vllm/utils.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index da79625572cd..d6791723a832 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -828,14 +828,7 @@ def make_mrope_positions_tensor_with_pad( \ input_mrope_positions: List[List[List[int]]], max_prompt_len: int, pad: int) -> List[List[int]]: - # If no mrope positions, returns a flatten (seq_len,) - if all(mrope_position is None for mrope_position in input_mrope_positions): - return make_tensor_with_pad(input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device='cpu').flatten() - # Otherwise, Qwen2.5-VL expects positions in a (3, seq_len) + # Qwen2.5-VL expects positions in a (3, seq_len) # we are going to pad each seq_data in the list # using either MRope values or regular position mrope_input_positions: List[List[int]] = [[] for _ in range(3)] From 73fadb4362cb6398a0dda7b3758961c1dc49e0d5 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Mon, 21 Apr 2025 16:48:47 +0000 Subject: [PATCH 24/38] linting --- vllm/worker/hpu_model_runner.py | 68 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 0a35c9db7b8d..0de92ec852ce 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -86,12 +86,13 @@ VLLM_MERGED_PREFILL = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' DUMMY_TOKEN_ID = -1 - - ''' This class is used to bucket image tokens ''' + + class VisionBuckets(): + def __init__(self): envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") if envvar == "": @@ -297,10 +298,8 @@ def __init__(self, model, vllm_config, layer_names): 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] if not htorch.utils.internal.is_lazy() and self.split_graph: - logger.warning( - f"[Multimodal] HPU is not in Lazy Mode, " - f"split graph has not impact" - ) + logger.warning(f"[Multimodal] HPU is not in Lazy Mode, " + f"split graph has not impact") if htorch.utils.internal.is_lazy() and self.split_graph: logger.info("[Multimodal] Split Graph to Visual and Language") self.model.visual = htorch.hpu.wrap_in_hpu_graph( @@ -489,7 +488,7 @@ def forward(self, *args, **kwargs): bc.env_setting, "PT_COMPILE_ONLY_MODE", False) with compile_only_mode_context(): - #calculate embedding for multimodal + # always calculate embeddings for multimodal image_input = self.model._parse_and_validate_image_input( **kwargs) video_input = self.model._parse_and_validate_video_input( @@ -789,7 +788,8 @@ def __init__( self.use_merged_prefill, self.max_model_len) self.graphed_buckets: Set[Any] = set() - self.multimodal_buckets = [] #This should be use HPUBucketingContext << + self.multimodal_buckets = [ + ] #This should be use HPUBucketingContext << self.graphed_multimodal_buckets: Set[Any] = set() self._set_gc_threshold() @@ -952,7 +952,6 @@ def load_model(self) -> None: logger.info(msg) self.add_vision_buckets_to_model() - def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): real_batch_size = len(seq_group_metadata_list) batch_size_padded = self.bucketing_ctx.get_padded_batch_size( @@ -1150,9 +1149,11 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype): def move_to_device(self, tensor): return tensor if 
tensor is None else tensor.to(self.device, non_blocking=True) + ''' Right now Qwen2.5VL needs to know these buckets so it can do some things internally ''' + def add_vision_buckets_to_model(self): model = self.get_model() if supports_multimodal(model): @@ -2030,9 +2031,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: return attention_metadata def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, - max_pixels, sampling_params, + max_pixels, + sampling_params, lora_request): - assert self.mm_registry.has_processor(self.model_config), 'Multimodal Warmup needs a processor' + assert self.mm_registry.has_processor( + self.model_config), 'Multimodal Warmup needs a processor' tokenizer = cached_get_tokenizer( self.model_config.tokenizer, trust_remote_code=self.model_config.trust_remote_code, @@ -2161,8 +2164,7 @@ def warmup_scenario(self, f"bs{batch_size}_" f"seq{seq_len}_" f"multimodal{max_pixels if max_pixels else 'F'}_" - f"graphs{'T' if use_graphs else 'F'}" - ) + f"graphs{'T' if use_graphs else 'F'}") # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -2198,8 +2200,7 @@ def warmup_scenario(self, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None, max_pixels=max_pixels, - temperature=temperature - ) for i in range(batch_size) + temperature=temperature) for i in range(batch_size) ] else: # FIXME: seq_len is actually number of blocks @@ -2327,9 +2328,8 @@ def _warmup_multimodal(self, kv_caches): num_candidates = len(self.multimodal_buckets) for i, max_pixels in enumerate(self.multimodal_buckets): - self.log_warmup_multimodal(phase, i, num_candidates, - batch_size, seq_len, - max_pixels) + self.log_warmup_multimodal(phase, i, num_candidates, batch_size, + seq_len, max_pixels) self.warmup_scenario(batch_size=batch_size, seq_len=seq_len, is_prompt=True, @@ -2411,10 +2411,10 @@ def warmup_graphs(self, return total_mem, total_batch_seq, captured_all def _warmup_multimodal_graph(self, - kv_caches, - available_mem, - starting_mem=0, - total_batch_seq=0.001): + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001): if not supports_multimodal(self.get_model()): return None total_mem = starting_mem @@ -2436,16 +2436,15 @@ def _warmup_multimodal_graph(self, if graphed_multimodal_bucket in self.graphed_multimodal_buckets: continue self.graphed_multimodal_buckets.add(graphed_multimodal_bucket) - self.log_warmup_multimodal(phase, idx, num_candidates, - batch_size, seq_len, max_pixels) + self.log_warmup_multimodal(phase, idx, num_candidates, batch_size, + seq_len, max_pixels) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario( - batch_size=batch_size, - seq_len=seq_len, - is_prompt=True, - kv_caches=kv_caches, - max_pixels=max_pixels) + self.warmup_scenario(batch_size=batch_size, + seq_len=seq_len, + is_prompt=True, + kv_caches=kv_caches, + max_pixels=max_pixels) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2466,14 +2465,14 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' f'buckets:{sorted(list(graphed))}') - logger.info(msg) + logger.info(msg) if "Prompt" in phase and len(self.multimodal_buckets) > 0: phase = "Graph/Multimodal" num_candidates = len(self.multimodal_buckets) mm_graphed = 
self.graphed_multimodal_buckets msg = (f'{phase} captured:{len(mm_graphed)} ' - f'({100 * len(mm_graphed) / num_candidates:.1f}%) ' - f'buckets:{sorted(list(mm_graphed))}') + f'({100 * len(mm_graphed) / num_candidates:.1f}%) ' + f'buckets:{sorted(list(mm_graphed))}') logger.info(msg) @torch.inference_mode() @@ -2826,7 +2825,8 @@ def _get_max_pixels_from_model_input(self, model_input): for pixel_values in pixel_values_list: assert isinstance(pixel_values, torch.Tensor) curr_num_pixels = pixel_values.shape[-2] - bucket_size = model.vision_buckets.get_multimodal_bucket(curr_num_pixels) + bucket_size = model.vision_buckets.get_multimodal_bucket( + curr_num_pixels) max_bucket_size = max(max_bucket_size, bucket_size) return max_bucket_size From 8aff50176edecf0d7bb521bb6a2055885c6b48b2 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Mon, 21 Apr 2025 16:51:33 +0000 Subject: [PATCH 25/38] always compute embeddings for qwen2.5vl, even text --- vllm/worker/hpu_model_runner.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 0de92ec852ce..3e4a78386a44 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -494,14 +494,11 @@ def forward(self, *args, **kwargs): video_input = self.model._parse_and_validate_video_input( **kwargs) - if image_input is None and video_input is None: - inputs_embeds = None - else: - inputs_embeds = self.model.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None + inputs_embeds = self.model.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None kwargs.update({ "input_ids": input_ids, From b2c020edbe5f2192209234998b41a578fd55b1e3 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Mon, 21 Apr 2025 22:59:56 +0000 Subject: [PATCH 26/38] simplify dummy_multi_modal replace: max_pixels -> num_patches --- vllm/worker/hpu_model_runner.py | 127 ++++++++++++++++---------------- 1 file changed, 63 insertions(+), 64 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 3e4a78386a44..6ec73e259ff7 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -78,7 +78,7 @@ # Use caution when updating them! _PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 -_UNSET_MAX_PIXELS = 9999999 +_UNSET_NUM_PATCHES = 9999999 LORA_WARMUP_RANK = 8 VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', @@ -1045,16 +1045,16 @@ def get_model(self) -> torch.nn.Module: return self.model.model return self.model - def _use_graphs(self, batch_size, seq_len, is_prompt, max_pixels=None): + def _use_graphs(self, batch_size, seq_len, is_prompt, num_patches=None): if self.enforce_eager: return False if self.skip_warmup: return True - if not max_pixels: + if not num_patches: return (batch_size, seq_len, is_prompt) in self.graphed_buckets #TODO: We might need to check both language bucket and multimodal bucket # and return True only it's avialble, or return seperately. 
- return (max_pixels) in self.graphed_multimodal_buckets + return (num_patches) in self.graphed_multimodal_buckets def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens @@ -2027,59 +2027,63 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: ]) return attention_metadata - def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_len, - max_pixels, + def create_dummy_multi_modal_seq_group_metadata(self, group_id, + num_patches, sampling_params, lora_request): - assert self.mm_registry.has_processor( - self.model_config), 'Multimodal Warmup needs a processor' - tokenizer = cached_get_tokenizer( - self.model_config.tokenizer, - trust_remote_code=self.model_config.trust_remote_code, - ) - processor = self.mm_registry.create_processor(self.model_config, - tokenizer) - mm_counts = self.mm_registry.get_mm_limits_per_prompt( - self.model_config) - factory = processor.dummy_inputs - processor_inputs = factory.get_dummy_processor_inputs( - seq_len=seq_len, - mm_counts=mm_counts, - ) - - hf_processor_mm_kwargs = dict(processor_inputs.hf_processor_mm_kwargs) - if max_pixels and max_pixels != _UNSET_MAX_PIXELS: - hf_processor_mm_kwargs["max_pixels"] = max_pixels - - mm_inputs = processor.apply( - prompt=processor_inputs.prompt_text, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - prompt_token_ids = mm_inputs["prompt_token_ids"] - placeholders_by_modality = mm_inputs["mm_placeholders"] - num_tokens_to_extend = seq_len - len(prompt_token_ids) - prompt_token_ids.extend([0] * (num_tokens_to_extend)) + assert self.model_is_mrope, "Warmup compatible with Qwen2vl models" + if num_patches == _UNSET_NUM_PATCHES: + # # only half of the total number of tokens should be from image + # num_image_tokens = self.model_config.max_model_len // 2 + # # get the number of patches from the num of image tokens + # num_patches = num_image_tokens * 4 + + # Using the largest bucket + num_patches = self.get_model().vision_buckets.multimodal_buckets[-1] + + # get number of tokens from num_patches using merger + # vision_config.spatial_merge_size + num_image_tokens = num_patches // 4 # TODO use the spatial_merge_size ** 2 insted of 4 + + image_token_id = self.get_model().config.image_token_id + prompt_token_ids = [image_token_id] * num_image_tokens + prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 + placeholders_by_modality = { + 'image': [{ + 'offset': 0, + 'length': len(prompt_token_ids) + }] + } seq_data = SequenceData.from_seqs(prompt_token_ids) + seq_data = SequenceData(prompt_token_ids_array) + + image_h = int(math.sqrt(num_patches)) + image_grid_thw = torch.tensor([1, image_h, image_h]) + pixel_values = torch.randn(image_grid_thw.prod(), 1176) # TODO: figure out the variable name + multi_modal_data = { + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + } + multi_modal_data = MultiModalKwargs(multi_modal_data) - return SequenceGroupMetadata( + seq_group = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=lora_request[group_id] if lora_request else None, - multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_data=multi_modal_data, multi_modal_placeholders=placeholders_by_modality, ) + return seq_group def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt, lora_request=None, - max_pixels=None, + num_patches=None, temperature=0): if self.is_pooler: 
sampling_params = None @@ -2087,11 +2091,10 @@ def create_dummy_seq_group_metadata(self, sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) seq_len = max(seq_len, 1) - if is_prompt and max_pixels: + if is_prompt and self.model_is_mrope and num_patches: return self.create_dummy_multi_modal_seq_group_metadata( group_id=group_id, - seq_len=seq_len, - max_pixels=max_pixels, + num_patches=num_patches, sampling_params=sampling_params, lora_request=lora_request, ) @@ -2132,7 +2135,7 @@ def profile_run(self) -> None: is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, - max_pixels=_UNSET_MAX_PIXELS, + num_patches=_UNSET_NUM_PATCHES, is_lora_profile_run=True, ) return @@ -2145,22 +2148,18 @@ def warmup_scenario(self, is_pt_profiler_run=False, is_lora_profile_run=False, temperature=0, - max_pixels=None) -> None: - if max_pixels: - # TODO: Find a better way to convert from raw pixel values to - # pixel_values from patches - max_pixels = max_pixels * 14**2 + num_patches=None) -> None: use_graphs = self._use_graphs( batch_size, seq_len, is_prompt, - max_pixels, + num_patches, ) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" f"seq{seq_len}_" - f"multimodal{max_pixels if max_pixels else 'F'}_" + f"multimodal{num_patches if num_patches else 'F'}_" f"graphs{'T' if use_graphs else 'F'}") # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory @@ -2196,7 +2195,7 @@ def warmup_scenario(self, is_prompt, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None, - max_pixels=max_pixels, + num_patches=num_patches, temperature=temperature) for i in range(batch_size) ] else: @@ -2304,14 +2303,14 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): logger.info(msg) def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, - max_pixels): + num_patches): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " f"{dim}:{seq_len} " - f"max_pixels:{max_pixels} " + f"num_patches:{num_patches} " f"free_mem:{free_mem}") logger.info(msg) @@ -2324,16 +2323,16 @@ def _warmup_multimodal(self, kv_caches): phase = 'Multimodal' num_candidates = len(self.multimodal_buckets) - for i, max_pixels in enumerate(self.multimodal_buckets): + for i, num_patches in enumerate(self.multimodal_buckets): self.log_warmup_multimodal(phase, i, num_candidates, batch_size, - seq_len, max_pixels) + seq_len, num_patches) self.warmup_scenario(batch_size=batch_size, seq_len=seq_len, is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, is_lora_profile_run=True, - max_pixels=max_pixels) + num_patches=num_patches) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): @@ -2419,29 +2418,29 @@ def _warmup_multimodal_graph(self, phase = f'Graph/Multimodal' num_candidates = len(self.multimodal_buckets) captured_all = True - for idx, max_pixels in enumerate(self.multimodal_buckets): + for idx, num_patches in enumerate(self.multimodal_buckets): batch_size = 1 # Note: Multimodal buckets are indepedent of batch_size _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() seq_len = max_seq_len - batch_seq = 1 * max_pixels + batch_seq = 1 * num_patches # Graph memory usage is proportional to seq dimension in a batch mem_estimate = batch_seq / 
total_batch_seq * total_mem if mem_estimate >= available_mem: captured_all = False continue - graphed_multimodal_bucket = max_pixels + graphed_multimodal_bucket = num_patches if graphed_multimodal_bucket in self.graphed_multimodal_buckets: continue self.graphed_multimodal_buckets.add(graphed_multimodal_bucket) self.log_warmup_multimodal(phase, idx, num_candidates, batch_size, - seq_len, max_pixels) + seq_len, num_patches) with HabanaMemoryProfiler() as mem_prof: self.warmup_scenario(batch_size=batch_size, seq_len=seq_len, is_prompt=True, kv_caches=kv_caches, - max_pixels=max_pixels) + num_patches=num_patches) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -2810,7 +2809,7 @@ def _get_seq_ids(self, model_input): sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups ]) - def _get_max_pixels_from_model_input(self, model_input): + def _get_num_patches_from_model_input(self, model_input): if not model_input.multi_modal_kwargs or 'pixel_values' not in model_input.multi_modal_kwargs: return None pixel_values_list = model_input.multi_modal_kwargs['pixel_values'] @@ -2924,11 +2923,11 @@ def execute_model( assert is_prompt is not None batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) - max_pixels = self._get_max_pixels_from_model_input(model_input) + num_patches = self._get_num_patches_from_model_input(model_input) use_graphs = self._use_graphs(batch_size=batch_size, seq_len=seq_len, is_prompt=is_prompt, - max_pixels=max_pixels) + num_patches=num_patches) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None From 372e793657849a7f9cb18c1ad880ca0129799b38 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Mon, 21 Apr 2025 22:34:53 +0000 Subject: [PATCH 27/38] Add VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO --- vllm/worker/hpu_model_runner.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 6ec73e259ff7..3c6430ceedbf 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2350,6 +2350,21 @@ def warmup_graphs(self, available_mem, starting_mem=0, total_batch_seq=0.001): + + if is_prompt and supports_multimodal(self.get_model()) and is_prompt: + multimodal_prompt_graph_mem_ratio = float( + os.environ.get('VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO', '0.3')) + multimodal_avail_mem = (multimodal_prompt_graph_mem_ratio * + available_mem) + available_mem = (available_mem - multimodal_avail_mem) + msg = ( + f"Using {format_bytes(multimodal_avail_mem)} for multimodal prompt and " + f"{format_bytes(available_mem)} for text prompt " + f"(VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO={multimodal_prompt_graph_mem_ratio})") + logger.info(msg) + else: + multimodal_avail_mem = 0 + total_mem = starting_mem idx = 0 phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' @@ -2397,10 +2412,11 @@ def warmup_graphs(self, mm_outputs = \ self._warmup_multimodal_graph( kv_caches=kv_caches, - available_mem=available_mem, - starting_mem=total_mem, + available_mem=multimodal_avail_mem, + starting_mem=0, total_batch_seq=total_batch_seq, ) + if mm_outputs is not None: total_mem, total_batch_seq, mm_captured_all = mm_outputs captured_all = captured_all and mm_captured_all From 458b9fa002000659adf2d0602b99ec768eb40da4 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 22 Apr 2025 20:55:19 +0000 Subject: [PATCH 28/38] Clean up some vars --- vllm/model_executor/models/qwen2_5_vl.py | 21 
++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 8cf898022d82..896b921ba345 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -108,24 +108,19 @@ def forward(q, k, v, mask, q_block_size): def create_block_diagonal_attention_mask_outerprod(indices): maxsize = indices[-1] range_to_max_for_each_img = torch.arange(maxsize, device=indices.device).unsqueeze(0).repeat(indices.shape[0]-1,1) - yy = range_to_max_for_each_img < indices[1:].unsqueeze(1) - zz = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) - xx = torch.logical_and(yy, zz).float() + lesser = range_to_max_for_each_img < indices[1:].unsqueeze(1) + greater_eq = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) + range_indices = torch.logical_and(lesser, greater_eq).float() # can reduce sum externally or as batchmatmul #TODO: einsum with tensor dimension too big doesn't work. Register max size error. #We can always move to CPU for all einsum without shape checking if perf impact is minimal. - if xx.shape[-1] > 40000: - print("einsum running on CPU : ", xx.shape) - xx = xx.to("cpu") - res = torch.einsum('bi,bj->ij', xx, xx) - #breakpoint() + if range_indices.shape[-1] > 40000: + print("einsum running on CPU : ", range_indices.shape) + range_indices = range_indices.to("cpu") + res = torch.einsum('bi,bj->ij', range_indices, range_indices) res = res.to("hpu") - #res = torch.sum(res, dim=0) else: - #res = torch.sum(torch.einsum('bi,bj->bij', xx, xx), dim=0) - res = torch.einsum('bi,bj->ij', xx, xx) - #print('.....MASK SHAPE', res.shape, indices) - #res = torch.einsum('bi,bj->ij', xx.float(), xx.float()) + res = torch.einsum('bi,bj->ij', range_indices, range_indices) return res.bool() def expand_to_max(indices, max_num_images): From 36213b9bccdda09407def852a4c9259554daa872 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Tue, 22 Apr 2025 20:13:44 +0000 Subject: [PATCH 29/38] Remove SPLIT flag for Qwen --- vllm/worker/hpu_model_runner.py | 53 +++++++++++++-------------------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 119c8e6553e4..2538a8ac2ed0 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -288,17 +288,13 @@ def __init__(self, model, vllm_config, layer_names, is_causal): model_config = getattr(self.model, "config", None) self.model_is_mrope = uses_mrope(model_config) - # For qwen2.5-VL model, we wrap visual model with disable_tensor_cache - # off due to handling of grid_thw. For langauge model, we wrap it with - # disable_tensor_cache on to save memory. Here we can either wrap it with - # self.model or self.model.language_model.model. - self.split_graph = self.model_is_mrope and os.getenv( - 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] - - if not htorch.utils.internal.is_lazy() and self.split_graph: + # This applies exclusively to Qwen2/2.5-VL model only(which use mrope) + # We split the model into visual and language components and wrap them separately + # with HPU graph. This is to ensure that we keeps the static and dynamic parts distint. 
+ if not htorch.utils.internal.is_lazy() and self.model_is_mrope: logger.warning(f"[Multimodal] HPU is not in Lazy Mode, " f"split graph has not impact") - if htorch.utils.internal.is_lazy() and self.split_graph: + if htorch.utils.internal.is_lazy() and self.model_is_mrope: logger.info("[Multimodal] Split Graph to Visual and Language") self.model.visual = htorch.hpu.wrap_in_hpu_graph( self.model.visual, disable_tensor_cache=True) @@ -464,24 +460,20 @@ def forward(self, *args, **kwargs): if self.layer_names is not None and not self.model_is_mrope: self._prepare_cos_sin(kwargs['positions']) - if self.model_is_mrope: # and self.split_graph: - if self.split_graph: - # Carry bypass_hpu_graphs to visual model forward. - bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False) - self.model.visual.forward = functools.partial( - self.model.visual.forward, - bypass_hpu_graphs=bypass_hpu_graphs) - self.model.language_model.model.forward = functools.partial( - self.model.language_model.model.forward, - bypass_hpu_graphs=bypass_hpu_graphs) - #self.model.forward = functools.partial( - # self.model.forward, bypass_hpu_graphs=bypass_hpu_graphs) - - # For Qwen2.5-VL multimodal embedding, - # This embedding part should be always executed with PT_COMPILE_ONLY_MODE off - # at all time. We are turning it off here since it will be on during warmup run. - # Also, we are moving this code block to here from model.forward() since we don't want - # to wrap this with hpu_graph. This block has issue with disable_tensor_cache=true. + if self.model_is_mrope: + bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False) + self.model.visual.forward = functools.partial( + self.model.visual.forward, + bypass_hpu_graphs=bypass_hpu_graphs) + self.model.language_model.model.forward = functools.partial( + self.model.language_model.model.forward, + bypass_hpu_graphs=bypass_hpu_graphs) + + # For Qwen2.5-VL multimodal embedding, this embedding part should be executed + # with PT_COMPILE_ONLY_MODE off at all times due to it's dynamicity. + # During warmup, this is ON by default, so we are turning it off here. + # Also, we moved this code block from model.forward() since we want to get + # embedding before pass it to model which is also aligned with VLLM V1. 
compile_only_mode_context = functools.partial( bc.env_setting, "PT_COMPILE_ONLY_MODE", False) @@ -963,12 +955,9 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): return seq_group_metadata_list, real_batch_size, batch_size_padded def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): - self.split_graph = self.model_is_mrope and os.getenv( - 'VLLM_QWEN_SPLIT_GRAPHS', 'false').lower() in ['1', 'true'] - if htorch.utils.internal.is_lazy() and not self.split_graph: + if htorch.utils.internal.is_lazy() and not self.model_is_mrope: return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter( - *args, **kwargs), - disable_tensor_cache=True) + *args, **kwargs), disable_tensor_cache=True) else: return HpuModelAdapter(*args, **kwargs) From 69e31118a85dea6242c2997adc78af9bb9bdef7d Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Tue, 22 Apr 2025 22:12:00 +0000 Subject: [PATCH 30/38] Using Qwen2_5_VisionTransformerStaticShape --- vllm/model_executor/models/qwen2_5_vl.py | 421 ++++++++++++++--------- 1 file changed, 249 insertions(+), 172 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 896b921ba345..08f15cac85e6 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -266,7 +266,6 @@ def __init__( self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, @@ -311,8 +310,7 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def forward( self, x: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], - fullattn_mask: Optional[torch.Tensor], + cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] @@ -354,14 +352,24 @@ def forward( context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == _Backend.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. 
- if cu_seqlens is None: + elif self.attn_backend == _Backend.TORCH_SDPA and is_hpu: + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + # We are abusing the variable name cu_seqlens + # to represent the mask for full attention, + # if the mask if None we are doing window attention + fullattn_mask = cu_seqlens + + if fullattn_mask is None: # performs window attention + # we assume image is 112 aligned in both h/w dims + # in other words, x % 64 = 0 + # that simplifies the slicing of window attention + # in patches of 64 outputs = [] - cu_seqlens = list(range(0, x.shape[0]+1, 64)) # assuming x%64=0 (image is 112 aligned in both h/w dims) + cu_seqlens = list(range(0, x.shape[0]+1, 64)) for i in range(1, len(cu_seqlens)): - #TODO: Check if number 100 is good - #For large image, we add mark step here for every 100th step to make compile time shorter + # For large image, we add mark step here + # for every 100th step to make compile time shorter if i % 100 == 0: htcore.mark_step() start_idx = cu_seqlens[i - 1] @@ -371,35 +379,48 @@ def forward( v_i = v[:, start_idx:end_idx] q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]) - if is_hpu: - output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0) - else: - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) else: + # performs full attention using the previous computed mask fullatt_block_attn_mask = fullattn_mask q1, k1, v1 = (rearrange(x, "b s h d -> b h s d")for x in [q, k, v]) - - - (batch_size, n_heads, seq_len_N_t, head_dim_qk) = q1.shape - (batch_size, n_heads, seq_len_N_s, head_dim_qk) = k1.shape + (batch_size, _, seq_len_N_t, _) = q1.shape + (batch_size, _, seq_len_N_s, _) = k1.shape mask_shape = (batch_size, 1, seq_len_N_t, seq_len_N_s) - attn_mask = fullatt_block_attn_mask.reshape(batch_size, 1, seq_len_N_t, seq_len_N_s, -1)[:, :, :, :, 0] + attn_mask = fullatt_block_attn_mask.reshape( + batch_size, + 1, + seq_len_N_t, + seq_len_N_s, + -1 + )[:, :, :, :, 0] # reshapes the mask to be Bx1xNxN assert attn_mask.shape == mask_shape - - #TODO:after 1by1 branch, even with long sequence, FusedSDPA is much faster - # Setting the number here to the max number we get in profile_run. - #print('BATCHSIZE:: ', batch_size, '.........................................') - if q1.shape[2] <= 65536: # this crossover point should be measured - fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0) # Bx1xNxN + if q1.shape[2] <= 65536: # this crossover point should be investigated + fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0) else: fused_out = AttentionLongSequence.forward(q1, k1, v1, attn_mask, 64) context_layer = rearrange(fused_out, "b h s d -> b s h d ") + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -446,13 +467,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp") - def forward(self, x: torch.Tensor, - cu_seqlens: torch.Tensor, - fullattn_mask: Optional[torch.Tensor], + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, - fullattn_mask=fullattn_mask, rotary_pos_emb=rotary_pos_emb) x = x + self.mlp(self.norm2(x)) return x @@ -632,9 +650,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, - ).permute( - 0, 2, 1, - 3).flatten() + ).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, @@ -684,7 +700,11 @@ def get_window_index(self, grid_thw): window_index = torch.cat(window_index, dim=0) return window_index, cu_window_seqlens - def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor): + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: # patchify hidden_states = x.to(device=self.device, dtype=self.dtype) hidden_states = self.patch_embed(hidden_states) @@ -695,28 +715,12 @@ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor): # windows attention window_index, cu_window_seqlens = self.get_window_index(grid_thw) - if is_hpu: - # NOTE: unique_consecutive is a dynamic operation - # we are using `remove_duplicates_cpu` instead - def remove_duplicates_cpu(a): - return [ - a[i] for i in range(len(a)) if i == 0 or a[i - 1] != a[i] - ] - - cu_window_seqlens = remove_duplicates_cpu(cu_window_seqlens) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype - if torch.jit.is_tracing() else torch.int32) - - else: - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype - if torch.jit.is_tracing() else torch.int32) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape( @@ -727,43 +731,26 @@ def remove_duplicates_cpu(a): seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + # compute cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) - return hidden_states, rotary_pos_emb, cu_seqlens, cu_window_seqlens, window_index - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - 
fullattn_mask: Optional[torch.Tensor], - rotary_pos_emb: torch.Tensor) -> torch.Tensor: - if is_hpu: - assert x.shape[0]%64 == 0, "Expect inputs to be 112x112 aligned. Please align before sending image or use this version of transformer that does the resizing/alignment automatically: pip install git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" - hidden_states = x.unsqueeze(1) + # transformers + hidden_states = hidden_states.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): - htcore.mark_step() - #TODO: now we premake fullattn_mask, we don't need to pass cu_seqlens - #but keep it here for now since other ATTN is using this argument. Need to clean code. if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens else: - cu_seqlens_now = None + cu_seqlens_now = cu_window_seqlens hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, - fullattn_mask=fullattn_mask, rotary_pos_emb=rotary_pos_emb) # adapter - - return hidden_states - - def post_attn(self, hidden_states: torch.Tensor, - window_index: torch.Tensor): hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) - hidden_states = hidden_states[reverse_indices, :] return hidden_states @@ -796,6 +783,168 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_params.add(name) return loaded_params +class Qwen2_5_VisionTransformerStaticShape(Qwen2_5_VisionTransformer): + """ + Here we overwrite some of the methods of Qwen2_5_VisionTransformer + to make the model more friendly to static shapes. Specifically, + we split the forward method into: + - pre_attn (dynamic) + - forward (static shape) + - post_attn (dynamic) + and we should call get_image_embeds instead of forward, allowing + the forward method ro run with HPU_Graphs, whereas the + pre_attn and post_attn methods are allow to be dynamic. 
+ """ + + def pad_multimodal_data(self, pixel_values, image_grid_thw, vision_buckets): + assert pixel_values.shape[ + 0] % 64 == 0, '[testing version] needs 64 aligned resolution' + + desired_number_of_pixels = vision_buckets.get_multimodal_bucket(pixel_values.shape[0]) + padding_len = desired_number_of_pixels - pixel_values.shape[0] + if padding_len <= 0: + return pixel_values, image_grid_thw + + logger.info( + f"[MM_BUCKETING] Padding current number pixel {pixel_values.shape[0]} to {desired_number_of_pixels}" + ) + # needs to make sure padding_len is even + assert padding_len % 64 == 0, '[testing version] padding needs to be multiple of 64' + + constant_value = -100 + pixel_values = torch.cat([ + pixel_values, + torch.ones((padding_len, pixel_values.shape[1]), device=pixel_values.device) * constant_value + ]) + + image_grid_thw = torch.cat( + [image_grid_thw, + torch.tensor([[1, 8, padding_len // 8]], device=image_grid_thw.device)]) + + assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels + return pixel_values, image_grid_thw + + def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor): + # patchify + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + + # NOTE: unique_consecutive is a dynamic operation + # we are using `remove_duplicates_cpu` instead + def remove_duplicates_cpu(a): + return [ + a[i] for i in range(len(a)) if i == 0 or a[i - 1] != a[i] + ] + + cu_window_seqlens = remove_duplicates_cpu(cu_window_seqlens) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + return hidden_states, rotary_pos_emb, cu_seqlens, cu_window_seqlens, window_index + + def forward( + self, + x: torch.Tensor, + fullattn_mask: Optional[torch.Tensor], + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + + assert_msg = ( + "Expect inputs to be 112x112 aligned. 
" + "Please align before sending image or use this version " + "of transformer that does the resizing/alignment automatically:" + "pip install " + "git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" + ) + assert x.shape[0]%64 == 0, assert_msg + hidden_states = x.unsqueeze(1) + for layer_num, blk in enumerate(self.blocks): + htcore.mark_step() + hidden_states = blk(hidden_states, + cu_seqlens=fullattn_mask if layer_num in self.fullatt_block_indexes else None, + rotary_pos_emb=rotary_pos_emb) + return hidden_states + + def post_attn(self, hidden_states: torch.Tensor, + window_index: torch.Tensor): + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + def get_image_embeds( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + vision_buckets, + ) -> torch.Tensor: + + assert pixel_values.shape[0] % 64 == 0, ( + f"We need image h/w to be aligned to 112 for now. " + f"Which will make pixel_values be a multiple of (112/14)*(112/14)=64" + f"(14 is patch size for ViT). " + f"Got pixel_values shape {pixel_values.shape[0]}" + ) + offset = 0 + results = [] + # process each image one by one + for img_idx in range(grid_thw.shape[0]): + img_shape = grid_thw[img_idx, :].unsqueeze(0) + curr_img_size = img_shape.prod() + + pixel_values_curr_img = pixel_values[offset : offset + curr_img_size, :] + + offset += curr_img_size + pixel_values_curr_img_padded, img_shape_padded = \ + self.pad_multimodal_data(pixel_values_curr_img, img_shape, vision_buckets=vision_buckets) + + pixel_values_curr_img_padded, rot_pos_emb, \ + cu_seqlens, _, window_index = self.pre_attn( + pixel_values_curr_img_padded, img_shape_padded) + + expanded_cu_seqlens = expand_to_max(cu_seqlens, 3) # either a single image, + # or a single image and its accompanying pad image, so only max expansion to 3 + + # Create full attention block mast before VisionTransformer to save memory/time + fullatt_block_attn_mask = create_block_diagonal_attention_mask_outerprod(cu_seqlens) + assert pixel_values_curr_img_padded.shape[0] == expanded_cu_seqlens[-1] == rot_pos_emb.shape[0] + + htcore.mark_step() + hidden_states = self.forward(pixel_values_curr_img_padded, + rotary_pos_emb=rot_pos_emb, + fullattn_mask=fullatt_block_attn_mask) + htcore.mark_step() + + image_embeds = self.post_attn(hidden_states, window_index) + # slice image_embeds to remove the padded parts + pad_index = img_shape_padded[0].prod() // self.spatial_merge_unit + results += [image_embeds[:pad_index , :]] + results_cat = torch.concat(results) + image_embeds = results_cat + return image_embeds class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): @@ -907,7 +1056,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - self.visual = Qwen2_5_VisionTransformer( + if is_hpu: + qwen2_5_visionTransformer = Qwen2_5_VisionTransformerStaticShape + else: + qwen2_5_visionTransformer = Qwen2_5_VisionTransformer + + self.visual = qwen2_5_visionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self._maybe_ignore_quant_config(quant_config), @@ -997,7 +1151,6 @@ def _parse_and_validate_video_input( video_grid_thw = kwargs.pop("video_grid_thw", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) - if pixel_values_videos is None and video_embeds is None: return None @@ -1028,36 +1181,6 @@ 
def _parse_and_validate_video_input( video_embeds=video_embeds, video_grid_thw=video_grid_thw) - def pad_multimodal_data(self, pixel_values, image_grid_thw): - assert pixel_values.shape[ - 0] % 64 == 0, '[testing version] needs 64 aligned resolution' - - desired_number_of_pixels = self.vision_buckets.get_multimodal_bucket(pixel_values.shape[0]) - padding_len = desired_number_of_pixels - pixel_values.shape[0] - if padding_len <= 0: - #breakpoint() - return pixel_values, image_grid_thw - - logger.info( - f"[MM_BUCKETING] Padding current number pixel {pixel_values.shape[0]} to {desired_number_of_pixels}" - ) - # needs to make sure padding_len is even - assert padding_len % 64 == 0, '[testing version] padding needs to be multiple of 64' - - constant_value = -100 - pixel_values = torch.cat([ - pixel_values, - torch.ones((padding_len, pixel_values.shape[1]), device=pixel_values.device) * constant_value - ]) - - image_grid_thw = torch.cat( - [image_grid_thw, - torch.tensor([[1, 8, padding_len // 8]], device=image_grid_thw.device)]) - - assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels - return pixel_values, image_grid_thw - - def _process_image_input( self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: @@ -1071,56 +1194,12 @@ def _process_image_input( pixel_values = image_input["pixel_values"].type(self.visual.dtype) if is_hpu: - ''' - go thru grid_thw - say grid_thw is 1,16,16 and 1,128,128 - say u have 2 buckets: 512 and 16384 - - slice pixel_values at 16*16 = 256 (1st img) and attach a new "image" to it, to pad it up to 512 - attach a new - ''' - - offset = 0 - # right now we do 1 img at a time, but if we have multiple small images we could pack them in together - # Like say if I have image = 224, 224, 6400, and my buckets are: 1024, 6400 - # instead of padding 224->1024 and 224->1024, we can pack both 224 into 1 and send it to 1024 - results = [] - # During warmup: self.model.visual_warmup_times isnt set, so we can do it 1-by-1 - # after warmup we need to check "visual_warmup_times" and we can batch based on that - # Note that sometimes we may recompile, in which case "_get_multimodal_bucket" will return a larger number - # however we will not have time for that larger size in "visual_warmup_times" - # so after that our policy will be: - # if size within original buckets attempt coalescing within original buckets - # if size is larger, only then use a already precompiled non-original bucket - for img_idx in range(grid_thw.shape[0]): - img_shape = grid_thw[img_idx, :].unsqueeze(0) - curr_img_size = img_shape.prod() - - pixel_values_curr_img = pixel_values[offset : offset + curr_img_size, :] - #breakpoint() - offset += curr_img_size - pixel_values_curr_img_padded, img_shape_padded = self.pad_multimodal_data(pixel_values_curr_img, img_shape) - - pixel_values_curr_img_padded, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( - pixel_values_curr_img_padded, img_shape_padded) - - assert pixel_values.shape[0] % 64 == 0, f"We need image h/w to be aligned to 112 for now. Which will make pixel_values be a multiple of (112/14)*(112/14)=64 (14 is patch size for ViT). 
Got pixel_values shape {pixel_values.shape[0]}" - - expanded_cu_seqlens = expand_to_max(cu_seqlens, 3) # either a single image, or a single image and its accompanying pad image, so only max expansion to 3 - #Create full attention block mast before VisionTransformer to save memory/time - #TODO cu_seqlens can be removed but keep it here for now - fullatt_block_attn_mask = create_block_diagonal_attention_mask_outerprod(cu_seqlens) - assert pixel_values_curr_img_padded.shape[0] == expanded_cu_seqlens[-1] == rot_pos_emb.shape[0] - htcore.mark_step() - hidden_states = self.visual(pixel_values_curr_img_padded, - rotary_pos_emb=rot_pos_emb, - cu_seqlens=expanded_cu_seqlens, - fullattn_mask=fullatt_block_attn_mask,) - htcore.mark_step() - image_embeds = self.visual.post_attn(hidden_states, window_index) - results += [image_embeds[:img_shape_padded[0].prod()//4, :]] # slice image_embeds to remove the padded parts. instead of hardcoding 4, maybe use config spatial merge etc - results_cat = torch.concat(results) - image_embeds = results_cat + assert isinstance(self.visual, Qwen2_5_VisionTransformerStaticShape) + image_embeds = self.visual.get_image_embeds( + pixel_values, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) else: image_embeds = self.visual(pixel_values, grid_thw=grid_thw) @@ -1143,18 +1222,15 @@ def _process_video_input( pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) - #Moved dynamic calculation to pre_attn, and post_attn and keep the visual() block to be static to include only VisionTransformer and VisionMerger. - pixel_values_videos, rot_pos_emb, cu_seqlens, cu_window_seqlens, window_index = self.visual.pre_attn( - pixel_values_videos, grid_thw) - expanded_cu_seqlens = expand_to_max(cu_seqlens, 10) - htcore.mark_step() # padding in expand_to_max is dynamic - hidden_states = self.visual(pixel_values_videos, - rotary_pos_emb=rot_pos_emb, - cu_seqlens=expanded_cu_seqlens,) - #cu_window_seqlens=cu_window_seqlens) - video_embeds = self.visual.post_attn(hidden_states, window_index) - #video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + if is_hpu: + assert isinstance(self.visual, Qwen2_5_VisionTransformerStaticShape) + video_embeds = self.visual.get_image_embeds( + pixel_values_videos, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
merge_size = self.visual.spatial_merge_size @@ -1220,6 +1296,7 @@ def get_input_embeddings_v0( image_input: Optional[tuple[torch.Tensor, ...]] = None, video_input: Optional[tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) if image_input is not None: image_embeds = self._process_image_input(image_input) From bdd279ac0184f081b6853dd43f673b3540136924 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Tue, 22 Apr 2025 22:15:27 +0000 Subject: [PATCH 31/38] no need to change this --- vllm/model_executor/models/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index e9598077a379..3e969415a842 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -373,7 +373,6 @@ def _merge_multimodal_embeddings( f"multimodal tokens to {num_expected_tokens} placeholders") inputs_embeds[is_multimodal] = flattened - return inputs_embeds @@ -598,12 +597,15 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, + context_size: int, dtype: torch.dtype, device: torch.device, ) -> IntermediateTensors: return IntermediateTensors({ key: - torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) + torch.zeros((batch_size, context_size, hidden_size), + dtype=dtype, + device=device) for key in keys }) From 21376343c211633715cce8f8c29259b42dcdcfc2 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 23 Apr 2025 16:11:30 -0500 Subject: [PATCH 32/38] Update vllm/model_executor/models/qwen2_5_vl.py Co-authored-by: Iman Gohari --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 08f15cac85e6..d8ed09555957 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -355,7 +355,7 @@ def forward( elif self.attn_backend == _Backend.TORCH_SDPA and is_hpu: from habana_frameworks.torch.hpex.kernels import FusedSDPA - # We are abusing the variable name cu_seqlens + # We are re-purposing the variable name cu_seqlens # to represent the mask for full attention, # if the mask if None we are doing window attention fullattn_mask = cu_seqlens From 14908e13467b923d936e0b1f91931e777c27827d Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 23 Apr 2025 16:14:38 -0500 Subject: [PATCH 33/38] Update vllm/worker/hpu_model_runner.py Co-authored-by: Iman Gohari --- vllm/worker/hpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 2538a8ac2ed0..112c35fc7144 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1175,7 +1175,7 @@ def move_to_device(self, tensor): non_blocking=True) ''' - Right now Qwen2.5VL needs to know these buckets so it can do some things internally + Qwen2.5VL requires the bucket's information ''' def add_vision_buckets_to_model(self): From c0d020730676ddec7db7121f12a3657c3fdcb1e7 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 23 Apr 2025 21:25:21 +0000 Subject: [PATCH 34/38] working on comments --- vllm/model_executor/models/qwen2_5_vl.py | 15 ++++----------- vllm/worker/hpu_model_runner.py | 9 --------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git 
a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index d8ed09555957..60bca871895b 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -88,10 +88,6 @@ def forward(q, k, v, mask, q_block_size): q_len = q.size(-2) assert q_len % q_block_size == 0 q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - #q_padding = q_tiles * q_block_size - q_len - #q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) - #if mask is not None: - # mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) attn_output = torch.zeros_like(q) for i in range(q_tiles): @@ -99,8 +95,7 @@ def forward(q, k, v, mask, q_block_size): row_q = q[:, :, s:e, :] row_mask = mask[:, :, s:e, :] attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None) - #TODO: markstep every 10th layer, didn't experiment which one is optimal number. - #10,50,100 shows simliar result, without this, we see the program hangs for multiple prompts(with larger images) + #TODO: markstep every so often, didn't experiment which one is optimal number. if i % 75 == 0: htcore.mark_step() return attn_output @@ -112,10 +107,8 @@ def create_block_diagonal_attention_mask_outerprod(indices): greater_eq = range_to_max_for_each_img >= indices[:-1].unsqueeze(1) range_indices = torch.logical_and(lesser, greater_eq).float() # can reduce sum externally or as batchmatmul - #TODO: einsum with tensor dimension too big doesn't work. Register max size error. - #We can always move to CPU for all einsum without shape checking if perf impact is minimal. if range_indices.shape[-1] > 40000: - print("einsum running on CPU : ", range_indices.shape) + logger.info("einsum running on CPU : ", range_indices.shape) range_indices = range_indices.to("cpu") res = torch.einsum('bi,bj->ij', range_indices, range_indices) res = res.to("hpu") @@ -798,7 +791,7 @@ class Qwen2_5_VisionTransformerStaticShape(Qwen2_5_VisionTransformer): def pad_multimodal_data(self, pixel_values, image_grid_thw, vision_buckets): assert pixel_values.shape[ - 0] % 64 == 0, '[testing version] needs 64 aligned resolution' + 0] % 64 == 0, 'needs 64 aligned resolution' desired_number_of_pixels = vision_buckets.get_multimodal_bucket(pixel_values.shape[0]) padding_len = desired_number_of_pixels - pixel_values.shape[0] @@ -1309,7 +1302,7 @@ def get_input_embeddings_v0( if video_input is not None: if is_hpu: - print("Video inputs have not been enabled/verified yet, ignoring video inputs") + logger.warning("Video inputs have not been enabled/verified yet, ignoring video inputs") return inputs_embeds video_embeds = self._process_video_input(video_input) inputs_embeds = merge_multimodal_embeddings( diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 112c35fc7144..4c3ebddd25da 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -96,8 +96,6 @@ class VisionBuckets(): def __init__(self): envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") if envvar == "": - #TODO:with profile_run, the bucket of 65536 is added, so the pixel values - #bigger than 12800 below always padded to 65536 which is too big. 
#self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544, 16384, 26500, 40000, 65536] multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] else: @@ -108,8 +106,6 @@ def get_multimodal_bucket(self, curr_num_image_patches): for mm_bucket in self.multimodal_buckets: if curr_num_image_patches <= mm_bucket: return mm_bucket - #Remove dynamic bucket expands since this is not done for the language model. - #self.multimodal_buckets += [curr_num_image_patches] # a shape larger than any that was compiled before. its gonna be compiled now, so save it for the future return curr_num_image_patches def __repr__(self): @@ -2060,11 +2056,6 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, lora_request): assert self.model_is_mrope, "Warmup compatible with Qwen2vl models" if num_patches == _UNSET_NUM_PATCHES: - # # only half of the total number of tokens should be from image - # num_image_tokens = self.model_config.max_model_len // 2 - # # get the number of patches from the num of image tokens - # num_patches = num_image_tokens * 4 - # Using the largest bucket num_patches = self.get_model().vision_buckets.multimodal_buckets[-1] From 96189f913d357bae884c4a8b8af95940f69e7d5d Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 23 Apr 2025 22:26:40 +0000 Subject: [PATCH 35/38] buckets needs to be multiples of 8 --- vllm/worker/hpu_model_runner.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 4c3ebddd25da..2fcdc917d961 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -52,7 +52,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) -from vllm.multimodal.utils import cached_get_tokenizer from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, @@ -96,11 +95,15 @@ class VisionBuckets(): def __init__(self): envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") if envvar == "": - #self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544, 16384, 26500, 40000, 65536] multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] else: multimodal_buckets = [int(i) for i in envvar.split(',')] - self.multimodal_buckets = sorted(multimodal_buckets) + self.multimodal_buckets = self._process_buckets(multimodal_buckets) + + def _process_buckets(self, buckets): + for bucket in buckets: + assert bucket % 8 == 0, 'Buckets needs to be multiples 8 (slices of 64)' + return sorted(buckets) def get_multimodal_bucket(self, curr_num_image_patches): for mm_bucket in self.multimodal_buckets: @@ -2075,9 +2078,17 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_data = SequenceData.from_seqs(prompt_token_ids) seq_data = SequenceData(prompt_token_ids_array) - image_h = int(math.sqrt(num_patches)) - image_grid_thw = torch.tensor([1, image_h, image_h]) + assert num_patches % 8, "Expect num_patches to be multiples of 8" + image_h = num_patches // 8 + image_grid_thw = torch.tensor([1, image_h, 8]) + + image_grid_thw = torch.tensor([1, image_h, int(num_patches/image_h)]) pixel_values = torch.randn(image_grid_thw.prod(), 1176) # TODO: figure out the variable name + + assert pixel_values.shape[0] % 64 == 0, ( + f"pixel_values must be sliced in 64 chunks, got: {pixel_values.shape}" + ) + multi_modal_data = { 
"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, From 7c9cf4ca760a2f13f80858efd230fda30030bb1e Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 23 Apr 2025 22:35:09 +0000 Subject: [PATCH 36/38] ops --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/worker/hpu_model_runner.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 60bca871895b..c11f15ef14d9 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -870,7 +870,7 @@ def forward( "pip install " "git+https://github.com/malkomes/transformers.git@e4269f72aebb00b82cc232866e6565597f6ceacf" ) - assert x.shape[0]%64 == 0, assert_msg + assert x.shape[0] % 64 == 0, assert_msg hidden_states = x.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): htcore.mark_step() diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 2fcdc917d961..217d23aedd65 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2078,7 +2078,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, seq_data = SequenceData.from_seqs(prompt_token_ids) seq_data = SequenceData(prompt_token_ids_array) - assert num_patches % 8, "Expect num_patches to be multiples of 8" + assert num_patches % 8 == 0, ( + f"Expects num_patches to be multiples of 8, got: {num_patches}" + ) image_h = num_patches // 8 image_grid_thw = torch.tensor([1, image_h, 8]) From 32d7855e3e42ce2d040b6e961882556ecf4d1d24 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Thu, 24 Apr 2025 00:01:00 +0000 Subject: [PATCH 37/38] Fix the multimodal warmup memory calculation --- vllm/worker/hpu_model_runner.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 217d23aedd65..8508822a70d3 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2382,6 +2382,7 @@ def warmup_graphs(self, starting_mem=0, total_batch_seq=0.001): + if is_prompt and supports_multimodal(self.get_model()) and is_prompt: multimodal_prompt_graph_mem_ratio = float( os.environ.get('VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO', '0.3')) @@ -2393,8 +2394,6 @@ def warmup_graphs(self, f"{format_bytes(available_mem)} for text prompt " f"(VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO={multimodal_prompt_graph_mem_ratio})") logger.info(msg) - else: - multimodal_avail_mem = 0 total_mem = starting_mem idx = 0 @@ -2440,17 +2439,23 @@ def warmup_graphs(self, total_mem += used_mem total_batch_seq += batch_seq - mm_outputs = \ - self._warmup_multimodal_graph( - kv_caches=kv_caches, - available_mem=multimodal_avail_mem, - starting_mem=0, - total_batch_seq=total_batch_seq, - ) + if is_prompt: + #For multimodal total_batch_seq and total_mem, we store it in the + #attribute for now. 
+ mm_outputs = self._warmup_multimodal_graph( + kv_caches=kv_caches, + available_mem=multimodal_avail_mem, + starting_mem=0 if not hasattr(self, "mm_total_mem") else self.mm_total_mem, + total_batch_seq=0.001 if not hasattr(self, "mm_total_batch_seq") else self.mm_total_batch_seq + ) + + if mm_outputs is not None: + mm_total_mem, total_mm_batch_seq, mm_captured_all = mm_outputs + total_mem = total_mem + mm_total_mem + captured_all = captured_all and mm_captured_all + self.mm_total_mem = mm_total_mem + self.mm_total_batch_seq= total_mm_batch_seq - if mm_outputs is not None: - total_mem, total_batch_seq, mm_captured_all = mm_outputs - captured_all = captured_all and mm_captured_all return total_mem, total_batch_seq, captured_all def _warmup_multimodal_graph(self, From 9a0ff19f3672c6de32e928d6f94519df420b4576 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Thu, 24 Apr 2025 04:10:51 +0000 Subject: [PATCH 38/38] Fixe print error for einsum --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c11f15ef14d9..2965da9879c7 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -108,7 +108,7 @@ def create_block_diagonal_attention_mask_outerprod(indices): range_indices = torch.logical_and(lesser, greater_eq).float() # can reduce sum externally or as batchmatmul if range_indices.shape[-1] > 40000: - logger.info("einsum running on CPU : ", range_indices.shape) + logger.info(f"einsum running on CPU : {range_indices.shape}") range_indices = range_indices.to("cpu") res = torch.einsum('bi,bj->ij', range_indices, range_indices) res = res.to("hpu")
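
The HPU attention path in the vision blocks above relies on 112-aligned inputs, so every window-attention slice covers exactly (112/14)**2 = 64 patches and the slicing stays static. A minimal sketch of that idea, using stock torch scaled_dot_product_attention in place of Habana's FusedSDPA and leaving out the HPU-specific mark_step calls; shapes and names are illustrative.

import torch
import torch.nn.functional as F
from einops import rearrange

def window_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     window: int = 64) -> torch.Tensor:
    # q, k, v: (batch, seq, heads, head_dim) with seq a multiple of `window`
    outputs = []
    for start in range(0, q.shape[1], window):
        # Each fixed-size window attends only to itself, so no mask is needed.
        q_i, k_i, v_i = (rearrange(x[:, start:start + window], "b s h d -> b h s d")
                         for x in (q, k, v))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(rearrange(out_i, "b h s d -> b s h d"))
    return torch.cat(outputs, dim=1)

x = torch.randn(1, 128, 16, 80)
assert window_attention(x, x, x).shape == x.shape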
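
The full-attention layers instead take a block-diagonal mask precomputed from the cumulative sequence lengths, built with an outer product as in create_block_diagonal_attention_mask_outerprod above. A small CPU-only sketch of that construction; only the tensor logic is taken from the patch, the helper name and example sizes are illustrative.

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # cu_seqlens: (num_images + 1,) cumulative patch counts, e.g. [0, 64, 192]
    max_size = int(cu_seqlens[-1])
    positions = torch.arange(max_size).unsqueeze(0).repeat(cu_seqlens.shape[0] - 1, 1)
    lesser = positions < cu_seqlens[1:].unsqueeze(1)
    greater_eq = positions >= cu_seqlens[:-1].unsqueeze(1)
    membership = torch.logical_and(lesser, greater_eq).float()  # (num_images, N) indicator rows
    # The outer product sums the per-image rank-1 blocks into one (N, N) block-diagonal mask.
    return torch.einsum('bi,bj->ij', membership, membership).bool()

mask = block_diagonal_mask(torch.tensor([0, 64, 192]))
assert mask.shape == (192, 192)
assert mask[0, 63] and not mask[0, 64]  # patches of different images cannot attend to each other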
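
VisionBuckets and pad_multimodal_data together keep pixel_values at one of a fixed set of shapes so the vision tower can be captured as HPU graphs: the patch count is rounded up to the nearest bucket and the remainder is filled with a dummy image. A simplified sketch; the default bucket list, the -100 pad value and the [1, 8, padding_len // 8] dummy grid row come from the patches above, everything else is illustrative.

import torch

MULTIMODAL_BUCKETS = [1600, 3136, 4096, 6400, 7744, 9216, 12544]

def get_multimodal_bucket(num_patches: int) -> int:
    # Smallest precompiled bucket that fits; fall back to the raw size if none does.
    for bucket in MULTIMODAL_BUCKETS:
        if num_patches <= bucket:
            return bucket
    return num_patches

def pad_pixel_values(pixel_values: torch.Tensor, image_grid_thw: torch.Tensor):
    # pixel_values: (num_patches, patch_dim), num_patches assumed to be a multiple of 64
    target = get_multimodal_bucket(pixel_values.shape[0])
    padding_len = target - pixel_values.shape[0]
    if padding_len <= 0:
        return pixel_values, image_grid_thw
    pad = torch.full((padding_len, pixel_values.shape[1]), -100.0,
                     dtype=pixel_values.dtype, device=pixel_values.device)
    pixel_values = torch.cat([pixel_values, pad])
    # The padding is described as one extra "image" of shape 1 x 8 x (padding_len // 8).
    extra_grid = torch.tensor([[1, 8, padding_len // 8]], device=image_grid_thw.device)
    image_grid_thw = torch.cat([image_grid_thw, extra_grid])
    return pixel_values, image_grid_thw

# Example: a 448x448 image gives (448/14) * (448/14) = 1024 patches, padded to the 1600 bucket.
pv, grid = pad_pixel_values(torch.randn(1024, 1176), torch.tensor([[1, 32, 32]]))
assert pv.shape[0] == 1600 and grid.prod(-1).sum() == 1600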
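
create_dummy_multi_modal_seq_group_metadata builds one synthetic image request per bucket for warmup; the number of image tokens is num_patches divided by spatial_merge_size**2 (4 for Qwen2.5-VL) and pixel_values carries one 1176-dimensional row per patch. A sketch of that relationship; the image token id default is taken from the debug print earlier in the series and the helper is otherwise illustrative.

import torch

def dummy_image_inputs(num_patches: int, image_token_id: int = 151655,
                       spatial_merge_size: int = 2, patch_dim: int = 1176):
    assert num_patches % 8 == 0, "buckets are expected to be multiples of 8"
    # Each merged token covers spatial_merge_size**2 patches.
    num_image_tokens = num_patches // spatial_merge_size**2
    prompt_token_ids = [image_token_id] * num_image_tokens
    image_grid_thw = torch.tensor([1, num_patches // 8, 8])
    pixel_values = torch.randn(int(image_grid_thw.prod()), patch_dim)
    return prompt_token_ids, pixel_values, image_grid_thw

toks, pv, grid = dummy_image_inputs(1600)
assert len(toks) == 400 and pv.shape == (1600, 1176)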
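
During graph warmup, VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO carves a fixed fraction of the prompt-graph memory budget out for multimodal buckets, and each candidate bucket is captured only if its extrapolated cost fits the remaining budget. A rough sketch of both calculations; the default ratio of 0.3 matches the patch, the function names are illustrative.

import os

def split_prompt_graph_mem(available_mem: int) -> tuple[int, int]:
    ratio = float(os.environ.get('VLLM_GRAPH_MULTIMODAL_PROMPT_RATIO', '0.3'))
    multimodal_mem = int(ratio * available_mem)
    return multimodal_mem, available_mem - multimodal_mem

def should_capture_bucket(num_patches: int, total_mem: float,
                          total_batch_seq: float, available_mem: float) -> bool:
    # Graph memory is assumed roughly proportional to the sequence dimension, so the
    # cost of a new bucket is extrapolated from the buckets captured so far.
    mem_estimate = num_patches / total_batch_seq * total_mem
    return mem_estimate < available_mem

mm_mem, text_mem = split_prompt_graph_mem(10 * 1024**3)  # 30% / 70% split by default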
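
Patch 23 makes text-only sequences use the same (3, seq_len) position layout as multimodal ones, so downstream M-RoPE code never has to branch on shape. A hedged sketch of the padding, assuming each sequence either carries precomputed M-RoPE positions or simply repeats its 1-D positions on all three axes; the exact padding layout of the real helper may differ.

import torch
from typing import List, Optional

def mrope_positions_with_pad(input_positions: List[List[int]],
                             input_mrope_positions: List[Optional[List[List[int]]]],
                             max_prompt_len: int, pad: int = 0) -> torch.Tensor:
    merged: List[List[int]] = [[] for _ in range(3)]
    for pos, mrope_pos in zip(input_positions, input_mrope_positions):
        # Text-only sequences reuse their flat positions on the t/h/w axes.
        seq = mrope_pos if mrope_pos is not None else [pos] * 3
        for dim in range(3):
            merged[dim].extend(seq[dim])
            merged[dim].extend([pad] * (max_prompt_len - len(seq[dim])))
    return torch.tensor(merged, dtype=torch.long)  # (3, batch * max_prompt_len)

out = mrope_positions_with_pad([[0, 1, 2]], [None], max_prompt_len=4)
assert out.shape == (3, 4)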