Skip to content
Merged
36 changes: 32 additions & 4 deletions vllm_gaudi/extension/bucketing/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
},
'qwen3_vl': {
'is_batch_based': False,
'buckets':
[256, 512, 1024, 1350, 1602, 2048, 3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076]
# patches per image
'buckets': [196, 256, 441, 480, 576, 900, 1156]
}
}

Expand All @@ -37,6 +37,8 @@ def __init__(self, model_name, is_batch_based=None):

self.is_batch_based = is_batch_based if is_batch_based is not None else config['is_batch_based']

self.qwen2_5_vl = 'qwen2_5_vl' in model_name.lower()

envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")

if envvar == 'None':
Expand Down Expand Up @@ -85,15 +87,16 @@ def find_factor(self, desired_patches, orig):
return None

def find_padding(self, h_orig, w_orig, desired_patches):
merge_size = 2
best_pad_h, best_pad_w = 0, 0
if desired_patches % h_orig == 0:
best_pad_h = 0
w_factor = desired_patches // h_orig
best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % 2 == 0) else 0
best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % merge_size == 0) else 0
elif desired_patches % w_orig == 0:
best_pad_w = 0
h_factor = desired_patches // w_orig
best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % 2 == 0) else 0
best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % merge_size == 0) else 0
elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0:
if h_orig > w_orig:
w_factor = self.find_factor(desired_patches, w_orig)
Expand Down Expand Up @@ -163,3 +166,28 @@ def greedy_plan(self, batchsize, available_batchsizes):

def __repr__(self):
return str(self.multimodal_buckets)

def bucket_to_image_resolution(self, patch_size: int = 14):
"""
Calculate image resolution by first determining height from target_patches,
then deriving width from aspect ratio.
"""
aspect_ratios = [
(1, 1), # 1:1 square
(4, 3), # 4:3 landscape
(3, 4), # 3:4 portrait
(16, 9), # 16:9 widescreen
(9, 16), # 9:16 portrait
]
merge_size = 2 # Qwen2.5/3VL spatial_merge_size
resolution_list = []
for target_patches in self.multimodal_buckets:
for (ratio_w, ratio_h) in aspect_ratios:
grid_h = int(target_patches**0.5)
height = grid_h * patch_size
width = int(height * ratio_w / ratio_h)
grid_w = width // patch_size
if grid_w * grid_h // merge_size != 0:
grid_w = ((grid_w + merge_size - 1) // merge_size) * merge_size
resolution_list.append((grid_w * patch_size, height))
return resolution_list
4 changes: 4 additions & 0 deletions vllm_gaudi/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def register_model():
from vllm_gaudi.models.qwen3_vl import HpuQwen3_VLForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen3VLForConditionalGeneration",
"vllm_gaudi.models.qwen3_vl:HpuQwen3_VLForConditionalGeneration")

from vllm_gaudi.models.qwen3_vl_moe import HpuQwen3_VLMoeForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen3VLMoeForConditionalGeneration",
"vllm_gaudi.models.qwen3_vl_moe:HpuQwen3_VLMoeForConditionalGeneration")
59 changes: 31 additions & 28 deletions vllm_gaudi/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import os
from functools import partial
from typing import Optional, Callable, Union
Expand Down Expand Up @@ -34,6 +33,7 @@
from vllm.model_executor.models.utils import (maybe_prefix, cast_overflow_tensors)

from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm_gaudi.extension.runtime import get_config

import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpex.kernels import FusedSDPA
Expand Down Expand Up @@ -72,28 +72,30 @@ class HPU_Attention:
in ['true', '1'] else 'None'

@classmethod
def forward(cls, q, k, v, mask, q_block_size=64):
def forward(cls, q, k, v, mask, cu_seqlens, qwen2_5_vl, q_block_size=64):
"""
Support long sequence at prompt phase
"""
q_len = q.size(-2)
if q_len <= 65536: # need to investigate this crosspoint
return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)

assert q_len % q_block_size == 0
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
attn_output = torch.zeros_like(q)

for i in range(q_tiles):
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = q[:, :, s:e, :]
row_mask = mask[:, :, s:e, :]
attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None, cls.softmax_mode)
# TODO: markstep after a couple of iterations
# need to experiment the optimal number.
if i % 75 == 0:
htcore.mark_step()
return attn_output
if qwen2_5_vl:
if q_len < 65536:
return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)
else:
return AttentionLongSequence.forward(q, k, v, mask, q_block_size, cls.softmax_mode)
else:
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if len(lens) == 1:
return FusedSDPA.apply(q, k, v, None, 0.0, False, None, cls.softmax_mode)
else:
q_chunks = torch.split(q, lens, dim=2)
k_chunks = torch.split(k, lens, dim=2)
v_chunks = torch.split(v, lens, dim=2)
outputs = []
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode)
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=2)
return context_layer


def create_block_diagonal_attention_mask(indices):
Expand Down Expand Up @@ -146,6 +148,8 @@ def __init__(
)

self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
model_type = get_config().model_type
self.qwen2_5_vl = 'qwen2_5_vl' in model_type.lower()

def forward(
self,
Expand Down Expand Up @@ -185,11 +189,9 @@ def forward(

# performs full attention using the previous computed mask
q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = HPU_Attention.forward(q1, k1, v1, attn_mask)
output = HPU_Attention.forward(q1, k1, v1, attn_mask, cu_seqlens, self.qwen2_5_vl)
context_layer = rearrange(output, "b h s d -> b s h d ")

context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output

Expand Down Expand Up @@ -233,13 +235,11 @@ def forward(
seqlens: Optional[list[int]] = None, # Only used for xFormers
attn_mask: Optional[torch.Tensor] = None, # Only used for HPU
) -> torch.Tensor:
mask_to_use = attn_mask if attn_mask is not None else cu_seqlens

x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
attn_mask=mask_to_use)
attn_mask=attn_mask)

x = x + self.mlp(self.norm2(x))
return x
Expand Down Expand Up @@ -356,7 +356,8 @@ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor):
)

def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor) -> torch.Tensor:
padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor,
cu_seqlens: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.unsqueeze(1)
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
Expand All @@ -366,9 +367,10 @@ def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,

hidden_states = blk(
hidden_states,
cu_seqlens=padding_attn_mask_now,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
attn_mask=padding_attn_mask_now,
)

# For Qwen2.5-VL-3B, float16 will overflow at last block
Expand Down Expand Up @@ -436,7 +438,8 @@ def get_image_embeds(
rotary_pos_emb_cos=rot_pos_emb_cos,
rotary_pos_emb_sin=rot_pos_emb_sin,
padding_attn_mask_window=padding_attn_mask_window,
padding_attn_mask_full=padding_attn_mask_full)
padding_attn_mask_full=padding_attn_mask_full,
cu_seqlens=cu_seqlens)
htcore.mark_step()

# remove padding
Expand Down
48 changes: 48 additions & 0 deletions vllm_gaudi/models/qwen3_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from torch import nn

from vllm.model_executor.models.qwen3_moe import (
Qwen3MoeSparseMoeBlock as UpstreamQwen3MoeSparseMoeBlock, )
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.distributed import tensor_model_parallel_all_gather


class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = orig_shape[-1]

hs = hidden_states.reshape(-1, hidden_dim) # (T, H)
num_tokens = hs.shape[0]

if getattr(self, "is_sequence_parallel", False):
hs = sequence_parallel_chunk(hs)

router_logits, _ = self.gate(hs)
out = self.experts(hidden_states=hs, router_logits=router_logits)

if getattr(self, "is_sequence_parallel", False):
out = tensor_model_parallel_all_gather(out, 0)
out = out[:num_tokens]

return out.reshape(*orig_shape[:-1], hidden_dim)


def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
lm_model = getattr(language_model, "model", None)
layers = getattr(lm_model, "layers", None)
if layers is None:
return

for layer in layers:
mlp = getattr(layer, "mlp", None)
if mlp is None:
continue

if isinstance(mlp, HpuQwen3MoeSparseMoeBlock):
continue

if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
mlp.__class__ = HpuQwen3MoeSparseMoeBlock
mlp._hpu_accept_3d_installed = True
Loading