diff --git a/vllm_gaudi/ops/hpu_rotary_embedding.py b/vllm_gaudi/ops/hpu_rotary_embedding.py index 38d02d23f9..cfc452927a 100644 --- a/vllm_gaudi/ops/hpu_rotary_embedding.py +++ b/vllm_gaudi/ops/hpu_rotary_embedding.py @@ -668,9 +668,14 @@ def forward_oot( cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat([m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1) - sin = torch.cat([m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1) + if getattr(self, "mrope_interleaved", False): + from vllm.model_executor.layers.rotary_embedding.mrope import apply_interleaved_rope + + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat([m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1) + sin = torch.cat([m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1) if self.is_neox_style: cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2) sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)