Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/configuration/env_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ and warm-up. Recommended settings for this case are:

!!! note
If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements.

## Additional Performance Tuning Parameters for the Attention Kernel

| Parameter name | Description | Default value |
| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_USE_BOOLEAN_MASK` | Use a boolean attention mask instead of a floating-point one. | `False` |
29 changes: 29 additions & 0 deletions tests/full_tests/ci_e2e_discoverable_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ run_gsm8k_granite_test() {
echo "➡️ Testing GSM8K on granite-8b..."
VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"

VLLM_USE_BOOLEAN_MASK=true VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 ASYNC_SCHEDULING=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"
echo "✅ Test with granite-8b passed."
}

Expand Down Expand Up @@ -333,6 +336,9 @@ run_gsm8k_deepseek_test() {
echo "➡️ Testing GSM8K on deepseek v2 lite..."
VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/DeepSeek-V2-Lite-chat.yaml"

VLLM_USE_BOOLEAN_MASK=true VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/DeepSeek-V2-Lite-chat.yaml"
echo "✅ GSM8K Test with deepseek v2 lite passed."
}

Expand All @@ -350,9 +356,31 @@ run_gsm8k_qwen3_30b_test() {
echo "➡️ Testing GSM8K on QWEN3-30B-A3B..."
VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 TP_SIZE=2 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml"

VLLM_USE_BOOLEAN_MASK=true VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 TP_SIZE=2 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml"
echo "✅ Test with QWEN3-30B-A3B passed."
}

# GSM8K on Gemma3
run_gsm8k_gemma3_test() {
    # Runs the common generation test against two Gemma3 model cards
    # (gemma-3-4b-it, then gemma-3-27b-it). Each card is exercised twice:
    # once without VLLM_USE_BOOLEAN_MASK on the command line, and once
    # with VLLM_USE_BOOLEAN_MASK=true.
    local test_script="${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py"
    local card_dir="${VLLM_GAUDI_PREFIX}/tests/models/language/generation/model_cards"
    local card_4b="${card_dir}/gemma-3-4b-it.yaml"
    local card_27b="${card_dir}/gemma-3-27b-it.yaml"

    # --- gemma-3-4b-it ---
    echo "➡️ Testing GSM8K on gemma-3-4b-it..."
    VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 \
        pytest -v -s "${test_script}" --model_card_path "${card_4b}"

    # Same card again with the boolean attention mask switched on.
    VLLM_USE_BOOLEAN_MASK=true VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 \
        pytest -v -s "${test_script}" --model_card_path "${card_4b}"
    echo "✅ Test with gemma-3-4b-it passed."

    # --- gemma-3-27b-it ---
    echo "➡️ Testing gemma-3-27b-it..."
    VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 \
        pytest -v -s "${test_script}" --model_card_path "${card_27b}"

    # Same card again with the boolean attention mask switched on.
    VLLM_USE_BOOLEAN_MASK=true VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 \
        pytest -v -s "${test_script}" --model_card_path "${card_27b}"
    echo "✅ Test with gemma-3-27b-it passed."
}


# GSM8K on Qwen3.5-9B
# TODO once Qwen3.5-35B-A3B compile time is improved, replace this test.
Expand Down Expand Up @@ -545,6 +573,7 @@ launch_all_tests() {
run_gsm8k_deepseek_test
#run_gsm8k_deepseek_unified_mla_test
run_gsm8k_qwen3_30b_test
run_gsm8k_gemma3_test
run_preemption_test
run_spec_decode_ngram_test
run_spec_decode_eagle3_test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
model_card:
model_name: google/gemma-3-27b-it
tasks: gsm8k
num_fewshot: 8
limit: 256

metrics:
name: exact_match,strict-match
value: 0.85
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
model_card:
model_name: google/gemma-3-4b-it
tasks: gsm8k
num_fewshot: 8
limit: 256

metrics:
name: exact_match,strict-match
value: 0.76
1 change: 1 addition & 0 deletions vllm_gaudi/extension/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,6 @@ def get_features():
All(VersionRange(">=1.24.0.460"), MinPackageVersion("neural_compressor_pt", "3.7")),
env_var_type=boolean),
Value('use_hpu_aligned_scale', False, env_var='HPU_ALIGNED_SCALE', env_var_type=boolean),
Value('use_boolean_mask', False, env_var='VLLM_USE_BOOLEAN_MASK', env_var_type=boolean),
Comment thread
yangulei marked this conversation as resolved.
Comment thread
yangulei marked this conversation as resolved.
]
return split_values_and_flags(features)
22 changes: 14 additions & 8 deletions vllm_gaudi/extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat
batch2block_matmul_op, block2batch_matmul_op):
# When fp32_softmax is enabled attn is left in fp32 after Q@K
# We can return to native dtype after we renormalize and calculate the adjustments
if block_bias is not None and attn.dtype != block_bias.dtype:
if block_bias is not None and block_bias.dtype != torch.bool and attn.dtype != block_bias.dtype:
block_bias = block_bias.to(dtype=attn.dtype)
# TODO: w/a with 5D req as the block_softmax kernel does not support 4D attn tensor, which is used in e.g. Granite-3B
if get_config().fused_block_softmax and get_config().fused_block_softmax_adjustment and attn.dim() == 5:
Expand All @@ -89,7 +89,10 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat
attn = attn.to(value.dtype)
else:
if block_bias is not None:
attn.add_(block_bias)
if block_bias.dtype == torch.bool:
attn.masked_fill_(~block_bias, -math.inf)
else:
attn.add_(block_bias)
Comment thread
yangulei marked this conversation as resolved.
block_max = attn.amax(dim=-1, keepdim=True)
if sink is not None:
block_max = torch.maximum(block_max, sink)
Expand All @@ -109,9 +112,9 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat
attn_sink = attn_sink.exp()
if attn_sink.dtype == torch.float32:
attn_sink = attn_sink.to(value.dtype)
#TODO: Removing this .sum and using attn_sink directly
#results in wrong output which does not make sense.
#Looks like a Synapse issue, need to investigate further.
# TODO: Removing this .sum and using attn_sink directly
# results in wrong output which does not make sense.
# Looks like a Synapse issue, need to investigate further.
block_sums_sink = attn_sink.sum(dim=-1, keepdim=True)
block_sums = block_sums + block_sums_sink
attn = matmul_av_op(attn, value)
Expand Down Expand Up @@ -357,9 +360,12 @@ def _naive_prompt_attention(query: torch.Tensor,
htcore.mark_step()
attn_weights.add_(position_bias)
if attn_bias is not None:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
if attn_bias.dtype == torch.bool:
attn_weights.masked_fill_(~attn_bias, -math.inf)
else:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
if sinks is not None:
sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
if query_heads != kv_heads:
Expand Down
73 changes: 51 additions & 22 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ def __init__(
self.use_hybrid_cache = os.getenv('VLLM_USE_HYBRID_CACHE', 'false').strip().lower() in ("1", "true")
self.use_naive_mamba_cache_sharing = os.getenv('VLLM_USE_NAIVE_MAMBA_CACHE_SHARING',
'true').strip().lower() in ("1", "true")
self.use_boolean_mask = get_config().use_boolean_mask

# Lazy initialization
# self.model: nn.Module # set after load_model
Expand Down Expand Up @@ -2144,7 +2145,7 @@ def _make_attn_bias(self, context_groups, token_groups):
causal_mask = torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool)
causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0)
attn_mask[:, :, context_len:].logical_or_(causal_mask)
attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf)
attn_mask = ~attn_mask if self.use_boolean_mask else attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf)
Comment thread
yangulei marked this conversation as resolved.

return attn_mask.unflatten(0, (1, -1))

Expand Down Expand Up @@ -6727,6 +6728,8 @@ def __init__(
# int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024"))
self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192'))

self.use_boolean_mask = get_config().use_boolean_mask

def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, device: torch.device,
dtype: torch.dtype) -> HPUAttentionMetadataV1:
"""
Expand Down Expand Up @@ -6771,7 +6774,10 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int,
diagonal=1)
mask = causal_mask.logical_or(len_mask)
mask = torch.concat((past_mask, mask), dim=-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))
if self.use_boolean_mask:
attn_bias = ~mask
else:
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))
attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias)
return attn_metadata

Expand Down Expand Up @@ -6830,16 +6836,24 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV
# seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len))
# causal_mask = causal_mask.logical_and(len_mask)

mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(float('-inf'), dtype=dtype, device=device))
if self.use_boolean_mask:
attn_bias = torch.concat((past_mask, causal_mask), dim=-1)
else:
mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(-math.inf, dtype=dtype, device=device))
else:
# CAUSAL MASK without removing padding (CAUSAL+sliding window)
# removing padding cause accuracy issue for images input
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
mask = torch.triu(mask, diagonal=shift - window_size + 1)
attn_bias = torch.log(mask)
if self.use_boolean_mask:
tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
mask = torch.tril(tensor, diagonal=shift)
attn_bias = torch.triu(mask, diagonal=shift - window_size + 1)
else:
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
mask = torch.triu(mask, diagonal=shift - window_size + 1)
attn_bias = torch.log(mask)

attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", window_attn_bias=attn_bias)
return attn_metadata
Expand Down Expand Up @@ -6893,18 +6907,30 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada
causal_mask = causal_mask & same_chunk_mask
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)

mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(float('-inf'), dtype=dtype, device=device))
if self.use_boolean_mask:
attn_bias = torch.concat((past_mask, causal_mask), dim=-1)
else:
mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(float('-inf'), dtype=dtype, device=device))
else:
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
attn_bias = torch.log(mask)
if self.use_boolean_mask:
tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
attn_bias = same_chunk & mask
else:
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
attn_bias = torch.log(mask)

attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias)
return attn_metadata
Expand Down Expand Up @@ -6944,8 +6970,11 @@ def _set_block_mapping(self,

block_size = getattr(metadata, "block_size", self.block_size)
mask = torch.arange(0, block_size, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))
if self.use_boolean_mask:
attn_bias = mask < block_usage.unsqueeze(-1)
else:
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))

if not is_fake_hpu():
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
Expand Down
Loading