14 changes: 7 additions & 7 deletions tests/full_tests/ci_e2e_discoverable_tests.sh
@@ -314,12 +314,12 @@ run_gsm8k_deepseek_test() {


# GSM8K on deepseek v2 lite + unified attn
run_gsm8k_deepseek_unified_mla_test() {
echo "➡️ Testing GSM8K on deepseek v2 lite + Unified MLA..."
VLLM_UNIFIED_ATTN=true VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/DeepSeek-V2-Lite-chat.yaml"
echo "✅ GSM8K Test with deepseek v2 lite + Unified MLA passed."
}
#run_gsm8k_deepseek_unified_mla_test() {
# echo "➡️ Testing GSM8K on deepseek v2 lite + Unified MLA..."
# VLLM_UNIFIED_ATTN=true VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 \
# pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/DeepSeek-V2-Lite-chat.yaml"
# echo "✅ GSM8K Test with deepseek v2 lite + Unified MLA passed."
#}

# GSM8K on QWEN3-30B-A3B
run_gsm8k_qwen3_30b_test() {
@@ -478,7 +478,7 @@ launch_all_tests() {
run_gsm8k_granite_async_test
run_gsm8k_granite_test_unified_attn_async
run_gsm8k_deepseek_test
run_gsm8k_deepseek_unified_mla_test
#run_gsm8k_deepseek_unified_mla_test
run_gsm8k_qwen3_30b_test
run_spec_decode_ngram_test
run_spec_decode_eagle3_test
89 changes: 54 additions & 35 deletions tests/unit_tests/multimodal/test_hpu_multimodal_inputs.py
@@ -16,35 +16,50 @@
)


def _dummy_items_from_tensors(tensors: NestedTensors, modalitity: str = "image"):
def _dummy_items_from_tensors(tensors: NestedTensors, modality: str = "image"):
"""
Creates MultiModalKwargsItems from a dict of modality to tensor.
Creates MultiModalKwargsItems from a list of tensors.
"""
elems = [
MultiModalFieldElem(modality=modalitity, key="key", data=t, field=MultiModalBatchedField()) for t in tensors
items = [
MultiModalKwargsItem({"key": MultiModalFieldElem(data=t, field=MultiModalBatchedField())}) for t in tensors
]
items = [MultiModalKwargsItem({"key": elem}) for elem in elems]
mm_items = MultiModalKwargsItems.from_seq(items)
mm_items = MultiModalKwargsItems({modality: items})
return mm_items


def _dummy_items_from_tensor_modalities(modality_tensor_dict: NestedTensors):
elems = [
MultiModalFieldElem(modality=modality, key="key", data=t, field=MultiModalBatchedField())
for modality, tensors in modality_tensor_dict.items() for t in tensors
]
items = [MultiModalKwargsItem({"key": elem}) for elem in elems]
mm_items = MultiModalKwargsItems.from_seq(items)
"""
Creates MultiModalKwargsItems from a dict of modality to list of tensors.
"""
items_by_modality = {}
for modality, tensors in modality_tensor_dict.items():
items = [
MultiModalKwargsItem({"key": MultiModalFieldElem(data=t, field=MultiModalBatchedField())}) for t in tensors
]
items_by_modality[modality] = items
mm_items = MultiModalKwargsItems(items_by_modality)
return mm_items


def _dummy_items_from_tensor_keys(key_tensor_dict: NestedTensors, modality: str = "image"):
elems = [
MultiModalFieldElem(modality=modality, key=key, data=t, field=MultiModalBatchedField())
for key, tensors in key_tensor_dict.items() for t in tensors
]
items = [MultiModalKwargsItem({elem.key: elem}) for elem in elems]
mm_items = MultiModalKwargsItems.from_seq(items)
def _dummy_items_from_tensor_keys(key_tensor_dict: dict[str, list], modality: str = "image"):
"""
Creates MultiModalKwargsItems from a dict of key names to list of tensors.
Creates items where each position combines tensors from all keys at that index.
For example: {"key1": [t1, t2], "key2": [t3, t4]} creates:
- Item 0: {key1: t1, key2: t3}
- Item 1: {key1: t2, key2: t4}
"""
# Get the number of items (should be same length for all keys)
num_items = len(next(iter(key_tensor_dict.values())))

items = []
for i in range(num_items):
item_dict = {}
for key, tensors in key_tensor_dict.items():
item_dict[key] = MultiModalFieldElem(data=tensors[i], field=MultiModalBatchedField())
items.append(MultiModalKwargsItem(item_dict))

mm_items = MultiModalKwargsItems({modality: items})
return mm_items


@@ -58,19 +73,13 @@ def assert_nested_tensors_equal_hpu(expected: NestedTensors, actual: NestedTenso
assert_nested_tensors_equal_hpu(expected_item, actual_item)


def assert_multimodal_kwargs_items_equal_hpu(expected_elems: MultiModalKwargsItem, actual: dict[str, NestedTensors]):
def assert_multimodal_kwargs_items_equal_hpu(expected: dict[str, NestedTensors], actual: dict[str, NestedTensors]):
"""HPU-aware assertion for multimodal input equality."""

assert set(expected_elems.keys()) == set(actual.keys())

for key in expected_elems:
if isinstance(expected_elems[key], list):
assert len(expected_elems[key]) == len(actual[key])
for expected_item, actual_item in zip(expected_elems[key], actual[key]):
assert_nested_tensors_equal_hpu(expected_item, actual_item)
continue
assert set(expected.keys()) == set(actual.keys())

assert_nested_tensors_equal_hpu(expected_elems[key], actual[key].data)
for key in expected:
assert_nested_tensors_equal_hpu(expected[key], actual[key])


@pytest.mark.parametrize(
@@ -200,7 +209,7 @@ def test_hpu_device_mismatch_handling(tensor_shapes):
result = dummy_kwargs_items.get_data()
expected = {"key": [hpu_tensor, cpu_tensor]}
# If successful, verify structure
assert_multimodal_kwargs_items_equal_hpu(result, expected)
assert_multimodal_kwargs_items_equal_hpu(expected, result)
except (RuntimeError, ValueError) as e:
# Expected behavior for device mismatch
assert "device" in str(e).lower() or "hpu" in str(e).lower()
@@ -240,19 +249,29 @@ def test_hpu_tensor_batching_sizes(tensor_size, batch_count):


def test_hpu_multiple_modalities():
"""Test MultiModalKwargsItems key handling."""
"""Test MultiModalKwargsItems handling of multiple modalities with different keys."""
device = "hpu"

# Test multiple modalities
# Test multiple modalities - each should have its own key
# This simulates a realistic scenario where different modalities
# produce different output keys (e.g., pixel_values for images, audio_features for audio)
image_tensor = torch.rand([3, 224, 224], device=device, dtype=torch.bfloat16)
audio_tensor = torch.rand([1000], device=device, dtype=torch.bfloat16)

batch_data = {"image": [image_tensor], "audio": [audio_tensor]}
# Create items with different keys for different modalities
image_items = [
MultiModalKwargsItem({"pixel_values": MultiModalFieldElem(data=image_tensor, field=MultiModalBatchedField())})
]
audio_items = [
MultiModalKwargsItem({"audio_features": MultiModalFieldElem(data=audio_tensor, field=MultiModalBatchedField())})
]

dummy_kwargs_items = MultiModalKwargsItems({"image": image_items, "audio": audio_items})

dummy_kwargs_items: MultiModalKwargsItems = _dummy_items_from_tensor_modalities(batch_data)
result = dummy_kwargs_items.get_data()

expected = {"key": [image_tensor, audio_tensor]}
# Each modality should have its own key in the output
expected = {"pixel_values": image_tensor.unsqueeze(0), "audio_features": audio_tensor.unsqueeze(0)}

assert_multimodal_kwargs_items_equal_hpu(expected, result)

1 change: 0 additions & 1 deletion tests/unit_tests/test_prefix_caching.py
@@ -16,7 +16,6 @@
def get_vllm_config():
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=True,
1 change: 0 additions & 1 deletion tests/unit_tests/worker/test_hpu_model_runner.py
@@ -62,7 +62,6 @@ def initialize_kv_cache(runner: HPUModelRunner):
def get_vllm_config():
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=True,
2 changes: 2 additions & 0 deletions vllm_gaudi/__init__.py
@@ -14,6 +14,8 @@ def register_utils():


def register_ops():
"""Register custom PluggableLayers for the HPU platform"""
import vllm_gaudi.attention.oot_mla # noqa: F401
"""Register custom ops for the HPU platform."""
import vllm_gaudi.v1.sample.hpu_rejection_sampler # noqa: F401
import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hpu_nixl_connector # noqa: F401
68 changes: 14 additions & 54 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -18,7 +18,7 @@

from vllm.v1.attention.backend import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata,
AttentionType)
from vllm.model_executor.layers.attention.mla_attention import MLACommonImpl
from vllm.model_executor.layers.attention.mla_attention import (MLACommonImpl)
from vllm_gaudi.attention.ops.hpu_paged_attn import (HPUPagedAttention, HPUPagedAttentionMetadata,
HPUPagedAttentionMetadataBuilder)

@@ -281,53 +281,16 @@ def __init__(
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError("output is not yet supported for MLAImplBase")

is_prefill = attn_metadata.is_prompt

if not is_prefill:
# decode
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None

latent_vec_k = torch.concat((k_c_normed, k_pe.view(*k_c_normed.shape[:-1], self.qk_rope_head_dim)), dim=-1)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)

# write the latent and rope to kv cache
if kv_cache is not None and len(kv_cache) >= 2:
self.latent_cache_k(latent_vec_k, kv_cache[0], slot_mapping)
k_cache = kv_cache[0]

if is_prefill:
return self._forward_prefill(q, latent_vec_k, k_cache, attn_metadata)
else:
return self._forward_decode(decode_ql_nope, q_pe, k_cache, attn_metadata)

def _forward_prefill( # type: ignore
def forward_mha( # type: ignore
self, q: torch.Tensor, latent_vec_k: torch.Tensor, k_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata) -> torch.Tensor:

##### get prefix cache #####
if attn_metadata.block_list is not None:
current = latent_vec_k
# Patch for vllm-gaudi kv_cache tuple format.
if isinstance(k_cache, tuple):
k_cache = k_cache[0] # Use only key_cache for MLA
past = self.latent_cache_k.fetch_from_cache(k_cache.unflatten(0, (-1, attn_metadata.block_size)),
attn_metadata.block_list)
past = past.view(-1, past.shape[-1])
@@ -382,9 +345,14 @@ def _forward_prefill( # type: ignore

return output.reshape(-1, self.num_heads * v.shape[-1])

def _forward_decode( # type: ignore
def forward_mqa( # type: ignore
self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
if k_cache is not None and isinstance(k_cache, tuple):
key_cache, value_cache, k_scales, v_scales = \
HPUPagedAttention.split_kv_cache(k_cache, self.num_kv_heads, self.head_size)
if isinstance(k_cache, tuple):
k_cache = k_cache[0] # Use only key_cache for MLA
query = torch.cat([q_nope, q_pe], dim=-1)
key_cache = k_cache.unsqueeze(1)
value_cache = None
@@ -404,8 +372,7 @@ def _forward_decode( # type: ignore
keys_fetch_func=self.latent_cache_k.fetch_from_cache,
values_fetch_func=None,
kv_lora_rank=self.kv_lora_rank)
result = self._v_up_proj(output)
return result
return output

# NOTE(Xinyu): Make the loaded weight contiguous to avoid the transpose
# during each graph execution
@@ -1228,15 +1195,8 @@ def get_supported_head_sizes() -> list[int]:
def is_mla(cls) -> bool:
return True

def _forward_decode(self, *args, **kwargs) -> torch.Tensor:
def forward_mqa(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("Use forward method for HPUUnifiedMLAImpl")

def _forward_prefill(self, *args, **kwargs) -> torch.Tensor:
def forward_mha(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("Use forward method for HPUUnifiedMLAImpl")

def process_weights_after_loading(self, act_dtype: torch.dtype):
# Parent MLACommonImpl extracts W_UV and W_UK_T from kv_b_proj weights
# These projection matrices are used for latent ↔ full space conversions
super().process_weights_after_loading(act_dtype)
self.W_UV: torch.Tensor = self.W_UV.contiguous()
self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous()