Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
4a2a025
add mask for couple of images in the same request
libinta Jan 26, 2026
644454b
fix accuracy issue for attn_mask path
libinta Jan 26, 2026
7afab16
fix accuracy issue
libinta Jan 27, 2026
6875827
remove not needed code
libinta Jan 27, 2026
750fcbc
fix precommit
libinta Jan 27, 2026
d839cbd
fix precommit
libinta Jan 27, 2026
c6e16a7
fix precommit
libinta Jan 27, 2026
fe5e05f
precommit fix
libinta Jan 27, 2026
2dce3d8
precommit fix
libinta Jan 27, 2026
fd78ad6
precommit fix
libinta Jan 27, 2026
b79ae59
precommit fix
libinta Jan 27, 2026
1ecb518
precommit fix
libinta Jan 27, 2026
c6e09ce
fix precommit
libinta Jan 27, 2026
0f98033
Enable qwen3_vl_moe model
shepark Jan 27, 2026
47f790b
Fix device mismatch in mrope tensor assignment
shepark Jan 27, 2026
e9c2052
fix precommits
shepark Jan 27, 2026
c37d68f
Update qwen3_vl.py for create_block_diagonal_mask optimization
libinta Jan 28, 2026
189ee3f
Implement bucket corrector for Mamba chunk size - v0.14.1 (#885)
jbyczkow Jan 28, 2026
a8def02
Revert "skip HPU graphs for long prefills" (#850) (#888)
adobrzyn Jan 28, 2026
15adc95
Cherry-picks to enable Llama4 Maverick (#882)
rsmyrek Jan 28, 2026
e104533
cherry-pick chunked attention from #821 + 32k+ context window fix fro…
Luca-Calabria Jan 28, 2026
0996a3b
Removed Attn_mask for qwen3
slokesha Jan 28, 2026
5120bf1
rebase and change bucket impl
libinta Jan 29, 2026
bf8a5f7
simplify logic
libinta Jan 29, 2026
b3f46e4
Fix a shape mismatch in mrope position slicing (#894)
shepark Jan 29, 2026
c3bc45b
Hpu granite 4.0-h small implementation (#883)
jbyczkow Jan 29, 2026
0a74f1b
Fix MultiModalBudget error (#892)
adobrzyn Jan 29, 2026
7c7c81a
rebase
libinta Jan 29, 2026
28bd042
precommit fix
libinta Jan 29, 2026
73e1472
precommit fix
libinta Jan 29, 2026
2d99b02
precommit fix
libinta Jan 29, 2026
bfb638b
precommit fix
libinta Jan 29, 2026
c6668b2
precommit fix
libinta Jan 29, 2026
a7072bb
Merge branch 'releases/v0.14.1' into libinta/add_mask
slokesha Jan 29, 2026
83669dd
fix gemma issue
libinta Jan 29, 2026
6371ea4
fix gemma3 issue
libinta Jan 29, 2026
19aac64
fix precommit
libinta Jan 29, 2026
1b7d074
remove uncessary code
libinta Jan 29, 2026
8f7ab84
precommit fix
libinta Jan 29, 2026
1439d0b
fix precommit comment
libinta Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions vllm_gaudi/extension/bucketing/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
},
'qwen3_vl': {
'is_batch_based': False,
'buckets':
[256, 512, 1024, 1350, 1602, 2048, 3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076]
# patches per image
'buckets': [196, 256, 441, 480, 576, 900, 1156]
}
}

Expand All @@ -37,6 +37,8 @@ def __init__(self, model_name, is_batch_based=None):

self.is_batch_based = is_batch_based if is_batch_based is not None else config['is_batch_based']

self.qwen2_5_vl = 'qwen2_5_vl' in model_name.lower()

envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")

if envvar == 'None':
Expand Down Expand Up @@ -85,15 +87,16 @@ def find_factor(self, desired_patches, orig):
return None

def find_padding(self, h_orig, w_orig, desired_patches):
merge_size = 2
best_pad_h, best_pad_w = 0, 0
if desired_patches % h_orig == 0:
best_pad_h = 0
w_factor = desired_patches // h_orig
best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % 2 == 0) else 0
best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % merge_size == 0) else 0
elif desired_patches % w_orig == 0:
best_pad_w = 0
h_factor = desired_patches // w_orig
best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % 2 == 0) else 0
best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % merge_size == 0) else 0
elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0:
if h_orig > w_orig:
w_factor = self.find_factor(desired_patches, w_orig)
Expand Down Expand Up @@ -163,3 +166,28 @@ def greedy_plan(self, batchsize, available_batchsizes):

def __repr__(self):
return str(self.multimodal_buckets)

def bucket_to_image_resolution(self, patch_size: int = 14):
"""
Calculate image resolution by first determining height from target_patches,
then deriving width from aspect ratio.
"""
aspect_ratios = [
(1, 1), # 1:1 square
(4, 3), # 4:3 landscape
(3, 4), # 3:4 portrait
(16, 9), # 16:9 widescreen
(9, 16), # 9:16 portrait
]
merge_size = 2 # Qwen2.5/3VL spatial_merge_size
resolution_list = []
for target_patches in self.multimodal_buckets:
for (ratio_w, ratio_h) in aspect_ratios:
grid_h = int(target_patches**0.5)
height = grid_h * patch_size
width = int(height * ratio_w / ratio_h)
grid_w = width // patch_size
if grid_w * grid_h // merge_size != 0:
grid_w = ((grid_w + merge_size - 1) // merge_size) * merge_size
resolution_list.append((grid_w * patch_size, height))
return resolution_list
4 changes: 4 additions & 0 deletions vllm_gaudi/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def register_model():
from vllm_gaudi.models.qwen3_vl import HpuQwen3_VLForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen3VLForConditionalGeneration",
"vllm_gaudi.models.qwen3_vl:HpuQwen3_VLForConditionalGeneration")

from vllm_gaudi.models.qwen3_vl_moe import HpuQwen3_VLMoeForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen3VLMoeForConditionalGeneration",
"vllm_gaudi.models.qwen3_vl_moe:HpuQwen3_VLMoeForConditionalGeneration")
43 changes: 21 additions & 22 deletions vllm_gaudi/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import os
from functools import partial
from typing import Optional, Callable, Union
Expand Down Expand Up @@ -34,6 +33,7 @@
from vllm.model_executor.models.utils import (maybe_prefix, cast_overflow_tensors)

from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm_gaudi.extension.runtime import get_config

import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpex.kernels import FusedSDPA
Expand Down Expand Up @@ -72,28 +72,27 @@ class HPU_Attention:
in ['true', '1'] else 'None'

@classmethod
def forward(cls, q, k, v, mask, q_block_size=64):
def forward(cls, q, k, v, mask, cu_seqlens, qwen2_5_vl, q_block_size=64):
"""
Support long sequence at prompt phase
"""
q_len = q.size(-2)
if q_len <= 65536: # need to investigate this crosspoint
return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)

assert q_len % q_block_size == 0
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
attn_output = torch.zeros_like(q)

for i in range(q_tiles):
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = q[:, :, s:e, :]
row_mask = mask[:, :, s:e, :]
attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None, cls.softmax_mode)
# TODO: markstep after a couple of iterations
# need to experiment the optimal number.
if i % 75 == 0:
htcore.mark_step()
return attn_output
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if mask is not None or len(lens) == 1:
if not qwen2_5_vl or (qwen2_5_vl and q_len < 65536):
return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)
else:
return AttentionLongSequence.forward(q, k, v, mask, q_block_size, cls.softmax_mode)
else:
q_chunks = torch.split(q, lens, dim=2)
k_chunks = torch.split(k, lens, dim=2)
v_chunks = torch.split(v, lens, dim=2)
outputs = []
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode)
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=2)
return context_layer


def create_block_diagonal_attention_mask(indices):
Expand Down Expand Up @@ -148,6 +147,8 @@ def __init__(
)

self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
model_type = get_config().model_type
self.qwen2_5_vl = 'qwen2_5_vl' in model_type.lower()

def forward(
self,
Expand Down Expand Up @@ -187,11 +188,9 @@ def forward(

# performs full attention using the previous computed mask
q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = HPU_Attention.forward(q1, k1, v1, attn_mask)
output = HPU_Attention.forward(q1, k1, v1, attn_mask, cu_seqlens, self.qwen2_5_vl)
context_layer = rearrange(output, "b h s d -> b s h d ")

context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output

Expand Down
48 changes: 48 additions & 0 deletions vllm_gaudi/models/qwen3_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from torch import nn

from vllm.model_executor.models.qwen3_moe import (
Qwen3MoeSparseMoeBlock as UpstreamQwen3MoeSparseMoeBlock, )
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.distributed import tensor_model_parallel_all_gather


class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = orig_shape[-1]

hs = hidden_states.reshape(-1, hidden_dim) # (T, H)
num_tokens = hs.shape[0]

if getattr(self, "is_sequence_parallel", False):
hs = sequence_parallel_chunk(hs)

router_logits, _ = self.gate(hs)
out = self.experts(hidden_states=hs, router_logits=router_logits)

if getattr(self, "is_sequence_parallel", False):
out = tensor_model_parallel_all_gather(out, 0)
out = out[:num_tokens]

return out.reshape(*orig_shape[:-1], hidden_dim)


def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
lm_model = getattr(language_model, "model", None)
layers = getattr(lm_model, "layers", None)
if layers is None:
return

for layer in layers:
mlp = getattr(layer, "mlp", None)
if mlp is None:
continue

if isinstance(mlp, HpuQwen3MoeSparseMoeBlock):
continue

if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
mlp.__class__ = HpuQwen3MoeSparseMoeBlock
mlp._hpu_accept_3d_installed = True
126 changes: 125 additions & 1 deletion vllm_gaudi/models/qwen3_vl.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import torch
import numpy as np
from .utils import _merge_multimodal_embeddings
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.interfaces import _require_is_multimodal

from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLImageInputs, )
from vllm.model_executor.models.qwen3_vl import (
Qwen3VLForConditionalGeneration,
Qwen3_VisionTransformer,
Qwen3_VisionBlock,
)
from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model

from vllm.model_executor.models.utils import maybe_prefix

from vllm_gaudi.models.qwen2_5_vl import (HPUQwen2_5_VisionAttention)
from vllm_gaudi.models.qwen2_5_vl import HPUQwen2_5_VisionAttention


class HPUQwen3_VisionBlock(Qwen3_VisionBlock):
Expand Down Expand Up @@ -48,6 +54,27 @@ def __init__(
prefix=f"{prefix}.attn",
)

def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
attn_mask=None,
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
attn_mask=attn_mask,
max_seqlen=max_seqlen,
)

x = x + self.mlp(self.norm2(x))
return x


class HPUQwen3_VisionTransformer(Qwen3_VisionTransformer):

Expand Down Expand Up @@ -83,6 +110,51 @@ def __init__(
) for layer_idx in range(depth)
])

def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)

if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()

pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
attn_mask=attn_mask,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat([hidden_states] + deepstack_feature_lists,
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states


class HpuQwen3_VLForConditionalGeneration(Qwen3VLForConditionalGeneration):

Expand All @@ -101,6 +173,58 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(prefix, "visual"),
)

def create_block_diagonal_mask(self,
cu_seqlens: torch.Tensor,
grid_thw: list[int],
device: torch.device = None,
dtype: torch.dtype = torch.bool) -> torch.Tensor:
"""
Create block diagonal mask that excludes padded tokens for Qwen3VL attention.
Args:
cu_seqlens: Cumulative sequence lengths from grid dimensions
grid_thw: The grid dimensions with merge_size=2 compatibility
device: Target device for the mask
dtype: Data type for the mask (typically torch.bool)

Returns:
Block diagonal attention mask with shape [total_seq_len, total_seq_len]
"""
if device is None:
device = cu_seqlens.device

# Calculate total sequence length including padding
total_patches = int(grid_thw.prod(-1).sum().item())
# Create mask with total size including padding
mask = torch.zeros(total_patches, total_patches, device=device, dtype=dtype)
cu_seqlens = cu_seqlens.tolist()
cu_seqlens = [0] + cu_seqlens
starts = cu_seqlens[:-1]
ends = cu_seqlens[1:]
for start, end in zip(starts, ends):
mask[start:end, start:end] = True
return mask

def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2

if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw.tolist(),
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw, attn_mask=None)

# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)

def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
Expand Down
Loading