diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 02ef4e2440cd..6bfcb3f66852 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -257,13 +257,13 @@ def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
         weights, _ = self.weights_proj(x)
         return weights.float()
 
-    @torch.compile(dynamic=True) if not _is_hip else lambda f: f
+    @torch.compile(dynamic=True)
     def _project_and_scale_head_gates(self, x: torch.Tensor):
         weights = self._weights_proj_bf16_in_fp32_out(x)
         weights = weights * self.n_heads**-0.5
         return weights
 
-    @torch.compile(dynamic=True) if not _is_hip else lambda f: f
+    @torch.compile(dynamic=True)
     def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
         weights = self._weights_proj_bf16_in_fp32_out(x)
         weights = weights * self.n_heads**-0.5
@@ -318,8 +318,8 @@ def _get_q_k_bf16(
 
         q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
 
-        query[..., : self.rope_head_dim] = q_rope.clone()
-        key[..., : self.rope_head_dim] = k_rope.clone()
+        self._update_rope_guarded(query[..., : self.rope_head_dim], q_rope)
+        self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope)
 
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
@@ -376,11 +376,19 @@ def _get_k_bf16(
         )
 
         _, k_rope = self.rotary_emb(positions, k_rope, k_rope)
-        key[..., : self.rope_head_dim] = k_rope.clone()
+        self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope)
 
         key = rotate_activation(key)
         return key
 
+    @staticmethod
+    def _update_rope_guarded(dst: torch.Tensor, src: torch.Tensor) -> None:
+        # On AMD with in-place RoPE kernels, self-aliasing can occur;
+        # skip write-back when src/dst tensors point to the same memory.
+        if src.data_ptr() == dst.data_ptr():
+            return
+        dst.copy_(src)
+
     def _get_topk_paged(
         self,
         forward_batch: ForwardBatch,