[FIX_FOR_VLLM_CUSTOM=d2f4a71cd54418369f617a174e6c839a71a47ed8] Hourly fixes – batch no. 1 #988

Merged
iboiko-habana merged 14 commits into vllm-project:main from pawel-olejniczak:dev/polejnix/fix_batch_1
Feb 19, 2026

Conversation

@pawel-olejniczak (Contributor)

This PR contains part of the fixes from #903.

Copilot AI left a comment

Pull request overview

This PR ports a subset of fixes from vllm-gaudi PR #903 to keep the Gaudi v1 worker, multimodal plumbing, and MLA attention integration aligned with upstream vLLM API changes.

Changes:

  • Replace deprecated vllm.envs profiler-dir access with os.getenv and update related logging (a sketch follows this list).
  • Update multimodal encoder batching/caching and unit tests to reflect upstream MultiModalKwargsItem(s) structure changes.
  • Refactor HPU MLA attention integration to match upstream method naming/signatures and patch tuple-based kv_cache handling.
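
For the profiler-dir change, a minimal sketch of the os.getenv-based handling is shown below. It assumes the environment variable is VLLM_TORCH_PROFILER_DIR as in upstream vLLM; the exact default value and log wording in hpu_worker.py may differ.

    import logging
    import os

    logger = logging.getLogger(__name__)

    # Read the torch-profiler output directory from the environment instead of
    # the deprecated vllm.envs attribute (illustrative only).
    torch_profiler_dir = os.getenv("VLLM_TORCH_PROFILER_DIR")
    if torch_profiler_dir:
        logger.info("Profiling enabled. Traces will be saved to: %s", torch_profiler_dir)
    else:
        logger.debug("Torch profiler disabled; VLLM_TORCH_PROFILER_DIR is not set.")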

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

Summary per file:
  • vllm_gaudi/v1/worker/hpu_worker.py: Switches torch-profiler directory handling off vllm.envs and updates initialization logic.
  • vllm_gaudi/v1/worker/hpu_model_runner.py: Adjusts multimodal encoder batching/caching and embedding gathering to match upstream multimodal changes.
  • vllm_gaudi/ops/hpu_fused_moe.py: Adds _select_monolithic() override for fused MoE behavior selection.
  • vllm_gaudi/attention/backends/hpu_attn.py: Refactors MLA attention implementation hooks and adds monkeypatches for upstream API alignment / kv_cache tuple format.
  • tests/unit_tests/worker/test_hpu_model_runner.py: Updates ModelConfig construction (removes task="generate").
  • tests/unit_tests/test_prefix_caching.py: Updates ModelConfig construction (removes task="generate").
  • tests/unit_tests/multimodal/test_hpu_multimodal_inputs.py: Updates test helpers/expectations to new multimodal kwargs item structure and modality handling.

Comments suppressed due to low confidence (1)

vllm_gaudi/attention/backends/hpu_attn.py:360

  • forward_mqa() will crash if k_cache is None (e.g., during profiling/initialization paths) because it unconditionally does k_cache.unsqueeze(1). Either assert k_cache is not None with a clear error, or handle the None case gracefully. Also, the split_kv_cache(...) result is currently unused—remove it or use the split outputs as intended.
    def forward_mqa(  # type: ignore
            self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
            attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
        if k_cache is not None and isinstance(k_cache, tuple):
            key_cache, value_cache, k_scales, v_scales = \
                HPUPagedAttention.split_kv_cache(k_cache, self.num_kv_heads, self.head_size)
        if isinstance(k_cache, tuple):
            k_cache = k_cache[0]  # Use only key_cache for MLA
        query = torch.cat([q_nope, q_pe], dim=-1)
        key_cache = k_cache.unsqueeze(1)
        value_cache = None
        output = HPUPagedAttention.forward_decode(query=query,


Comment thread vllm_gaudi/v1/worker/hpu_worker.py Outdated
Comment thread vllm_gaudi/v1/worker/hpu_model_runner.py Outdated
Comment on lines +1517 to +1518
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True

Copilot AI Feb 18, 2026

is_mm_embed[...] = True ignores the is_embed mask for partial/placeholder ranges. When pos_info.is_embed is not None, this will mark non-embedding placeholder positions as multimodal and can cause a mismatch between is_mm_embed.sum() and the flattened multimodal embedding count (see _merge_multimodal_embeddings raising on count mismatch). Set this slice to is_embed (or an equivalent mask) when present, rather than unconditionally True.

Suggested change

    # Replace:
    is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
        = True

    # With:
    slice_start = req_start_pos + start_idx
    slice_end = req_start_pos + end_idx
    if is_embed is None:
        is_mm_embed[slice_start:slice_end] = True
    else:
        is_mm_embed[slice_start:slice_end] = is_embed

Comment on lines 1507 to +1515
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]

sliced_output = encoder_output[start_idx:end_idx]
mm_embeds_item = sliced_output if is_embed is None else sliced_output[is_embed]


Copilot AI Feb 18, 2026

Redundant logic: mm_embeds_item is computed in the if (is_embed := pos_info.is_embed) ... else ... block and then immediately overwritten by sliced_output/sliced_output[is_embed]. This makes the earlier branch dead code and obscures which indices are intended (curr_embeds_start/curr_embeds_end vs start_idx/end_idx). Please remove the dead assignment and keep a single, clearly-correct slicing path.
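
A minimal sketch of what that single slicing path could look like, reusing the names from the quoted diff (pos_info, start_idx, end_idx, encoder_output); it is illustrative only and the PR's final code may differ.

    is_embed = pos_info.is_embed
    if is_embed is not None:
        # Restrict the mask to the positions covered by this item.
        is_embed = is_embed[start_idx:end_idx]

    # One explicit slicing path: slice by token positions, then apply the
    # per-position embedding mask only when one is present.
    sliced_output = encoder_output[start_idx:end_idx]
    mm_embeds_item = sliced_output if is_embed is None else sliced_output[is_embed]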

Comment thread vllm_gaudi/attention/backends/hpu_attn.py Outdated
Comment on lines +827 to +829
    def __init__(self):
        super(MLAAttention, self).__init__()
        self.latent_cache_k = VLLMKVCache()

Copilot AI Feb 18, 2026

HPUMLAAttention.__init__ calls super(MLAAttention, self).__init__(), which skips MLAAttention.__init__ and is almost certainly not what you want. Use super().__init__() (or remove this class entirely if it only exists to source methods for monkeypatching).
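
For reference, a minimal sketch of the cooperative form the comment recommends; whether MLAAttention.__init__ can be called without extra arguments here is an assumption not confirmed by the quoted snippet.

    def __init__(self):
        # Zero-argument super() resolves through the MRO, so MLAAttention.__init__
        # runs instead of being skipped (illustrative sketch only).
        super().__init__()
        self.latent_cache_k = VLLMKVCache()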

    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError("output is not yet supported for MLAImplBase")
        self.latent_cache_k = VLLMKVCache()

Copilot AI Feb 18, 2026

In the monkeypatched forward_impl, self.latent_cache_k = VLLMKVCache() is assigned every call but is not referenced afterwards (the cache used is self.impl.latent_cache_k). This looks like dead code at best, and at worst could inadvertently reset state if upstream MLAAttention starts using self.latent_cache_k. Please remove it or ensure the cache reset happens on the correct object.

Suggested change
self.latent_cache_k = VLLMKVCache()

@pawel-olejniczak force-pushed the dev/polejnix/fix_batch_1 branch 2 times, most recently from 2986e56 to f87d2f9 on February 18, 2026, 14:18
All 14 commits are signed off by: Paweł Olejniczak <polejniczakx@habana.ai>
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
d2f4a71cd54418369f617a174e6c839a71a47ed8

@iboiko-habana merged commit c79d6fb into vllm-project:main on Feb 19, 2026
64 checks passed
SKRohit pushed a commit to SKRohit/vllm-gaudi that referenced this pull request Feb 20, 2026
… fixes – batch no. 1 (vllm-project#988)

This PR contains part of fixes from
vllm-project#903

---------

Signed-off-by: Paweł Olejniczak <polejniczakx@habana.ai>
Signed-off-by: Rohit kumar Singh <rksingh@habana.ai>
gyou2021 pushed a commit to gyou2021/vllm-gaudi that referenced this pull request Feb 21, 2026
… fixes – batch no. 1 (vllm-project#988)

This PR contains part of fixes from
vllm-project#903

---------

Signed-off-by: Paweł Olejniczak <polejniczakx@habana.ai>

3 participants