Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README_GAUDI.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ batch size is often at its maximum, making large-batch HPU graphs critical to ca
- `VLLM_HANDLE_TOPK_DUPLICATES`: if ``true`` - handles duplicates outside top-k. The default is `false`.
- `VLLM_CONFIG_HIDDEN_LAYERS`: configures how many hidden layers to run in a HPUGraph for model splitting among hidden layers when TP is 1. It helps to improve throughput by reducing inter-token latency limitations in some models. The default is `1`.
- `VLLM_SKIP_WARMUP`: if `true`, warm-up is skipped. The default is `false`.
- `VLLM_FUSEDSDPA_SLIDE_RIGHT`: right sliding window size when fusedsdpa used with sliding window. It helps with memory and performance when long context is used. The default is `0`. Example: for sliding window of size 1024, set VLLM_FUSEDSDPA_SLIDE_RIGHT=1024

> [!TIP]
> When a deployed workload does not utilize the full context that a model can handle, it is good practice to limit the maximum values upfront based on the input and output token lengths that will be generated after serving the vLLM server.
Expand Down
2 changes: 1 addition & 1 deletion requirements/hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ ray
triton==3.1.0
setuptools>=77.0.3
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@5135570
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@009adb2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is almost identical to #1660. Please apply comments from there

5 changes: 4 additions & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
window_block_usage: Optional[torch.Tensor] = None
window_attn_bias: Optional[torch.Tensor] = None
use_window_sdpa: Optional[bool] = None
sliding_window_right: Optional[int] = None


@dataclass
Expand Down Expand Up @@ -405,6 +406,7 @@ def __init__(

self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window

self.prompt_position_bias = None
self.prev_attn = None
self.alibi_slopes = None
Expand Down Expand Up @@ -549,7 +551,8 @@ def forward(

if attn_metadata.use_window_sdpa:
attn_bias = attn_metadata.attn_bias
window_size = (self.sliding_window, 0)
window_size = (self.sliding_window,
attn_metadata.sliding_window_right)
common_args['window_size'] = window_size
# TODO: Currently HPU doesn't support GQA for FusedSDPA
# with causal + window, so repeat KV so QKV are all the
Expand Down
81 changes: 56 additions & 25 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,42 @@ def forward(self,

return hidden_states

def hpu_build_mask(self, input_ids: torch.Tensor,
mask_dtype: torch.dtype) -> torch.Tensor:
bs, seq_len = input_ids.shape
device = input_ids.device
img_tokens = self.config.mm_tokens_per_image
image_token_index = self.config.image_token_index
# bool causal mask (True == masked)
causal_bool = torch.triu(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), 1)
mask_bool = causal_bool.unsqueeze(0).unsqueeze(0).expand(
bs, 1, -1, -1).clone()

# pre-compute a few broadcastable helpers
img_pos = (input_ids == image_token_index) # [B,S]
img_row = img_pos.unsqueeze(1).unsqueeze(3) # [B,1,S,1]
img_col = img_pos.unsqueeze(1).unsqueeze(2) # [B,1,1,S]

img_pos_cum = torch.cumsum(img_pos, 1)
img_causal = torch.arange(seq_len, device=device).unsqueeze(0) \
- img_pos_cum + (img_pos_cum // img_tokens + 1) * img_tokens + 1
img_causal = torch.cat((img_causal[:, :1] - 1, img_causal[:, :-1]), 1) \
.clamp_(0, seq_len - 1) \
.unsqueeze(1).unsqueeze(3) # [B,1,S,1]
ind = torch.arange(seq_len, device=device).view(1, 1, 1,
-1) # [1,1,1,S]

# positions we must *unmask* (row img ∧ col img
# ∧ col < img_causal)
allow = img_row & img_col & (ind < img_causal)
mask_bool &= ~allow # flip to False
# 4) final bfp16/32 version
out = torch.zeros_like(mask_bool, dtype=mask_dtype) \
.masked_fill(mask_bool, float("-inf"))

return out

def prepare_attn_masks(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -696,40 +732,35 @@ def prepare_attn_masks(
local_attn_masks = []
start_idx = 0
for seq_len in seq_lens:
if not is_hpu:
if is_hpu:
global_attn_mask = self.hpu_build_mask(input_ids, mask_dtype)
else:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
bs = 1
else:
input_token_ids = input_ids
# Create a global causal mask.
global_attn_mask = torch.empty(
bs,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Create a global causal mask.
global_attn_mask = torch.empty(
bs,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)

# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = (input_token_ids == self.config.image_token_index)
# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = (input_token_ids == self.config.image_token_index)

if not is_hpu:
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
else:
img_mask[img_pos.unsqueeze(1)] += 1
img_mask = img_mask.permute(0, 1, 3, 2)
img_mask[img_pos.unsqueeze(1)] += 1
img_mask = img_mask.permute(0, 1, 3, 2)
global_attn_mask = torch.where(img_mask == 2, 0,
global_attn_mask)

global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

if self.sliding_window is not None:
Expand Down
26 changes: 12 additions & 14 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,20 @@ def __init__(self, model, vllm_config, is_causal, sampler):
self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD",
"false").strip().lower() in ("1",
"true")
self.sliding_window_right = 0
if self.use_window_sdpa:
self.slice_size = int(
os.getenv("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", "1024"))

os.environ["PT_HPU_SDPA_BC_FACTOR"] = str(self.slice_size)
os.environ["PT_HPU_SDPA_BR_FACTOR"] = str(self.slice_size)
os.environ["PT_HPU_QKV_SLICE_SEQ_LEN_THLD"] = str(self.slice_size)
self.sliding_window_right = int(
os.environ.get('VLLM_FUSEDSDPA_SLIDE_RIGHT', '0'))
assert self.sliding_window_right % self.slice_size == 0, \
f'VLLM_FUSEDSDPA_SLIDE_RIGHT({self.sliding_window_right}) '\
f'not supported due to not a multiplier of '\
f'PT_HPU_QKV_SLICE_SEQ_LEN_THLD({self.slice_size})!'

# This applies exclusively to Qwen2/2.5-VL models
# both use mrope. We wrap the visual and language
Expand Down Expand Up @@ -580,6 +587,8 @@ def _update_use_window_sdpa(self, attn_metadata, seq_len):
f"VLLM_PROMPT_SEQ_BUCKET_STEP: 1024 ")

attn_metadata = attn_metadata._replace(use_window_sdpa=use_window_sdpa)
attn_metadata = attn_metadata._replace(
sliding_window_right=self.sliding_window_right)
return attn_metadata

def _update_metadata(self,
Expand Down Expand Up @@ -609,7 +618,6 @@ def _update_metadata(self,
attn_metadata = self._set_attn_bias_for_sliding_window(
attn_metadata, batch_size, seq_len,
self.interleaved_sliding_window, device, dtype)

else:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
device, dtype, False)
Expand Down Expand Up @@ -1482,17 +1490,6 @@ def move_to_device(self, tensor):
return tensor if tensor is None else tensor.to(self.device,
non_blocking=True)

def _get_position_pad(self) -> int:
"""
For gemma3 models,
due to the Hack in Gemma3ForConditionalGeneration::prepare_attn_masks,
'0' can't be used as pad for input position tensor.
In case, it might have '0's for bucketing, those '0' will be counted as
new sequence in the prepare_attn_masks() which is wrong.
"""
model_type = getattr(self.model_config.hf_config, 'model_type', '')
return -1 if model_type == 'gemma3' else 0

def add_vision_buckets_to_mrope_mm_optimized(self):
if self.mm_registry is not None:
model = self.get_model()
Expand Down Expand Up @@ -1748,11 +1745,11 @@ def _prepare_prompt(
make_mrope_positions_tensor_with_pad(input_positions=input_positions,
input_mrope_positions=input_mrope_positions,
max_prompt_len=max_prompt_len,
pad=self._get_position_pad())
pad=0)
else:
input_positions = make_cpu_tensor(input_positions,
max_len=max_prompt_len,
pad=self._get_position_pad(),
pad=0,
dtype=torch.long,
flat=self.use_merged_prefill)

Expand Down Expand Up @@ -2662,6 +2659,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'window_block_groups',
'window_attn_bias',
'use_window_sdpa',
'sliding_window_right',
])
return attention_metadata

Expand Down