Mrope accuracy fix for qwen #1437

Merged
adobrzyn merged 4 commits into vllm-project:main from hsubramony:mrope_fix on May 15, 2026

@hsubramony (Contributor) commented May 11, 2026

When mrope_interleaved is enabled, HPUMRotaryEmbedding was still using the non-interleaved split/concat section mapping for cos/sin.
This produced incorrect rotary channel ordering for multimodal MRoPE inputs and could cause sample-level mismatches against upstream vLLM behavior.
Use apply_interleaved_rope for the interleaved branch, and preserve the existing split/concat logic for non-interleaved layouts.

Co-authored-by: Jimin Ha <jimin.ha@intel.com>
Signed-off-by: Harish Subramony <harish.subramony@intel.com>
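
As a rough illustration of the two layouts described above, the sketch below maps each rotary frequency channel to the position axis (T, H, or W) that supplies it. The helper name `axis_per_channel` and the toy `mrope_section` of `[4, 2, 2]` are assumptions made for illustration, not code from this PR. The interleaved pattern is assumed to follow the strided indices `1::3` and `2::3` used in the review snippets on this page.

```python
def axis_per_channel(mrope_section, interleaved):
    """Return one letter per frequency channel naming the axis that feeds it.

    Hypothetical helper: it only illustrates the channel ordering, it does not
    compute any cos/sin values.
    """
    t, h, w = mrope_section
    if not interleaved:
        # Non-interleaved (split/concat): contiguous T block, then H, then W.
        return "T" * t + "H" * h + "W" * w

    # Interleaved: start with every channel on T, then overwrite strided slots.
    axes = ["T"] * (t + h + w)
    for i in range(1, h * 3, 3):
        axes[i] = "H"
    for i in range(2, w * 3, 3):
        axes[i] = "W"
    return "".join(axes)


print(axis_per_channel([4, 2, 2], interleaved=False))  # TTTTHHWW
print(axis_per_channel([4, 2, 2], interleaved=True))   # THWTHWTT
```

Applying the non-interleaved mapping to a model that expects the interleaved one gives channels to the wrong axis, which is consistent with the channel-ordering mismatch reported in the description.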

Copilot AI (Contributor) left a comment

Pull request overview

This PR updates the HPU implementation of MRotaryEmbedding to improve multimodal MRoPE correctness/accuracy (notably for Qwen-style 3-axis positions) by changing how cos/sin values are assembled for the rotary kernel.

Changes:

  • Add an optional prepare_mrope_cache mode to precompute three sparse cos/sin caches (per T/H/W axis) that can be combined by addition.
  • Normalize/reshape incoming positions more defensively (handling [3, seq_len, 1], [1, seq_len], and flattened forms).
  • Switch the multimodal [3, seq_len] path from per-step split/concat to cache-sum assembly.

Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
Comment on lines +639 to +654
assert self.mrope_section

sin_start_idx = self.rotary_dim // 2
if getattr(self, "mrope_interleaved", False):
    mrope1_slice = (
        list(range(1, self.mrope_section[1] * 3, 3))
        + list(range(sin_start_idx + 1, sin_start_idx + self.mrope_section[1] * 3, 3)))
    mrope2_slice = (
        list(range(2, self.mrope_section[2] * 3, 3))
        + list(range(sin_start_idx + 2, sin_start_idx + self.mrope_section[2] * 3, 3)))
else:
    c0 = self.mrope_section[0]
    c1 = c0 + self.mrope_section[1]
    c2 = c1 + self.mrope_section[2]
    mrope1_slice = list(range(c0, c1)) + list(range(sin_start_idx + c0, sin_start_idx + c1))
    mrope2_slice = list(range(c1, c2)) + list(range(sin_start_idx + c1, sin_start_idx + c2))
Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
if offsets is not None:
    offsets = offsets.view(positions.shape[0], -1)
Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
Comment on lines +713 to +714
if positions.shape[0] != num_tokens:
    positions = positions.view(-1, num_tokens)
Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
Comment on lines +656 to +672
self.cos_sin_cache_mrope1 = torch.zeros_like(self.cos_sin_cache)
self.cos_sin_cache_mrope2 = torch.zeros_like(self.cos_sin_cache)
self.cos_sin_cache_mrope1[..., mrope1_slice] = self.cos_sin_cache[..., mrope1_slice]
self.cos_sin_cache_mrope2[..., mrope2_slice] = self.cos_sin_cache[..., mrope2_slice]
self.cos_sin_cache_mrope0 = self.cos_sin_cache.clone()
self.cos_sin_cache_mrope0[..., mrope1_slice] = 0
self.cos_sin_cache_mrope0[..., mrope2_slice] = 0

def repeat_cache(cos_sin_cache: torch.Tensor) -> torch.Tensor:
    if self.is_neox_style:
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        return torch.cat((cos, cos, sin, sin), dim=-1)
    return torch.repeat_interleave(cos_sin_cache, 2, dim=-1)

self.cos_sin_cache_mrope0 = repeat_cache(self.cos_sin_cache_mrope0)
self.cos_sin_cache_mrope1 = repeat_cache(self.cos_sin_cache_mrope1)
self.cos_sin_cache_mrope2 = repeat_cache(self.cos_sin_cache_mrope2)
Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
Comment on lines +656 to +673
self.cos_sin_cache_mrope1 = torch.zeros_like(self.cos_sin_cache)
self.cos_sin_cache_mrope2 = torch.zeros_like(self.cos_sin_cache)
self.cos_sin_cache_mrope1[..., mrope1_slice] = self.cos_sin_cache[..., mrope1_slice]
self.cos_sin_cache_mrope2[..., mrope2_slice] = self.cos_sin_cache[..., mrope2_slice]
self.cos_sin_cache_mrope0 = self.cos_sin_cache.clone()
self.cos_sin_cache_mrope0[..., mrope1_slice] = 0
self.cos_sin_cache_mrope0[..., mrope2_slice] = 0

def repeat_cache(cos_sin_cache: torch.Tensor) -> torch.Tensor:
    if self.is_neox_style:
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        return torch.cat((cos, cos, sin, sin), dim=-1)
    return torch.repeat_interleave(cos_sin_cache, 2, dim=-1)

self.cos_sin_cache_mrope0 = repeat_cache(self.cos_sin_cache_mrope0)
self.cos_sin_cache_mrope1 = repeat_cache(self.cos_sin_cache_mrope1)
self.cos_sin_cache_mrope2 = repeat_cache(self.cos_sin_cache_mrope2)
self._mrope_hpu_cache_prepared = True
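
To make the two branches of `repeat_cache` in the snippet above easier to compare, here is a minimal pure-Python sketch that mimics them on a single cache row laid out as `[cos | sin]`. The function name `repeat_cache_row` and the string labels are hypothetical; the real code operates on `torch.Tensor` caches.

```python
def repeat_cache_row(row, is_neox_style):
    """Mimic repeat_cache on one [cos | sin] row, using plain lists."""
    if is_neox_style:
        # NeoX style: duplicate each half as a block -> [cos, cos, sin, sin].
        half = len(row) // 2
        cos, sin = row[:half], row[half:]
        return cos + cos + sin + sin
    # Otherwise: repeat every element in place, like repeat_interleave(x, 2).
    return [value for value in row for _ in range(2)]


row = ["c0", "c1", "s0", "s1"]
print(repeat_cache_row(row, is_neox_style=True))
# ['c0', 'c1', 'c0', 'c1', 's0', 's1', 's0', 's1']
print(repeat_cache_row(row, is_neox_style=False))
# ['c0', 'c0', 'c1', 'c1', 's0', 's0', 's1', 's1']
```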
Comment thread vllm_gaudi/ops/hpu_rotary_embedding.py Outdated
Comment on lines +718 to +723
use_mrope_cache_sum = positions.ndim == 2 and positions.shape[0] == 3
if use_mrope_cache_sum:
    if not getattr(self, "_mrope_hpu_cache_prepared", False):
        self.prepare_cos_sin(positions, offsets, prepare_mrope_cache=True)
    cos_sin = (self.cos_sin_cache_mrope0[positions[0]] + self.cos_sin_cache_mrope1[positions[1]] +
               self.cos_sin_cache_mrope2[positions[2]])
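
The cache-sum idea in this snippet can be checked on a toy example: if the three caches are non-zero on disjoint channel sets, adding the rows looked up at the T, H, and W positions yields one row in which each channel comes from its own axis. The sketch below uses plain Python lists instead of tensors. The helper names (`mrope_slices`, `split_cache`, `gather_sum`), the toy cache values, and the `[4, 2, 2]` section are all assumptions for illustration, not the PR's implementation.

```python
def mrope_slices(mrope_section, rotary_dim, interleaved):
    """Channel indices (cos half plus sin half) owned by the H and W axes."""
    half = rotary_dim // 2  # cache rows are laid out as [cos | sin]
    if interleaved:
        h = list(range(1, mrope_section[1] * 3, 3))
        w = list(range(2, mrope_section[2] * 3, 3))
    else:
        c0 = mrope_section[0]
        c1 = c0 + mrope_section[1]
        c2 = c1 + mrope_section[2]
        h = list(range(c0, c1))
        w = list(range(c1, c2))
    return h + [half + i for i in h], w + [half + i for i in w]


def split_cache(cache, h_idx, w_idx):
    """Split one cache into three sparse caches with disjoint non-zero channels."""
    h_set, w_set = set(h_idx), set(w_idx)
    cache_t = [[0 if (j in h_set or j in w_set) else v for j, v in enumerate(row)]
               for row in cache]
    cache_h = [[v if j in h_set else 0 for j, v in enumerate(row)] for row in cache]
    cache_w = [[v if j in w_set else 0 for j, v in enumerate(row)] for row in cache]
    return cache_t, cache_h, cache_w


def gather_sum(caches, t, h, w):
    """Look up each sparse cache at its own position and add the rows."""
    cache_t, cache_h, cache_w = caches
    return [a + b + c for a, b, c in zip(cache_t[t], cache_h[h], cache_w[w])]


rotary_dim = 16
# Toy cache: the value encodes (position, channel) so the source row is visible.
cache = [[pos * 100 + j for j in range(rotary_dim)] for pos in range(10)]
h_idx, w_idx = mrope_slices([4, 2, 2], rotary_dim, interleaved=True)
row = gather_sum(split_cache(cache, h_idx, w_idx), t=3, h=5, w=7)
print(row)
# [300, 501, 702, 303, 504, 705, 306, 307, 308, 509, 710, 311, 512, 713, 314, 315]
```

The hundreds digit of each output value shows which position supplied that channel: 3 for T, 5 for H, 7 for W.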
@hsubramony hsubramony marked this pull request as draft May 11, 2026 23:25
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

This is a Draft PR. Please mark it as 'Ready for Review' to trigger the CI.

@hsubramony hsubramony marked this pull request as ready for review May 11, 2026 23:44
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
54f548e9e58087f0155e4e164e416ad7efdfde6d

@adobrzyn adobrzyn merged commit 7a4d2fe into vllm-project:main May 15, 2026
2 checks passed
iboiko-habana pushed a commit that referenced this pull request May 18, 2026
When mrope_interleaved is enabled, HPUMRotaryEmbedding was still using
the non-interleaved split/concat section mapping for cos/sin.
This produced incorrect rotary channel ordering for multimodal MRoPE
inputs and could cause sample-level mismatches against upstream vLLM
behavior.
Use apply_interleaved_rope for the interleaved branch, and preserve the
existing split/concat logic for non-interleaved layouts.

Signed-off-by: Harish Subramony <harish.subramony@intel.com>
Co-authored-by: Jimin Ha <jimin.ha@intel.com>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: Seunghyuk Park (shepark) <separk@habana.ai>