From 30df8b6b6ae6428e53e463ba089c05e2833e4683 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Mon, 24 Feb 2025 11:58:17 +0800 Subject: [PATCH] [DeepSeek] Fix some bugs for dsk-v3 (#9874) * use fused rope * fix import --- .../transformers/deepseek_v2/modeling.py | 40 ++++++------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py index 6d1364857ae5..037208f4a151 100644 --- a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py @@ -92,12 +92,10 @@ def __init__( * attn_factor ) - cos_cache, sin_cache = self._compute_cos_sin_cache() + cache = self._compute_cos_sin_cache() - self.cos_cache: paddle.Tensor - self.register_buffer("cos_cache", cos_cache, persistable=True) - self.sin_cache: paddle.Tensor - self.register_buffer("sin_cache", sin_cache, persistable=True) + self.cos_sin_cache: paddle.Tensor + self.register_buffer("cos_sin_cache", cache, persistable=True) def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor: pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim) @@ -116,13 +114,11 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor: def _compute_cos_sin_cache(self) -> paddle.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32) - - freqs = paddle.outer(t, inv_freq) - emb = paddle.concat((freqs, freqs), axis=-1) - cos = emb.cos() * self.mscale - sin = emb.sin() * self.mscale - - return cos.cast(self._dtype), sin.cast(self._dtype) + freqs = paddle.einsum("i,j->ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = paddle.concat((cos, sin), axis=-1) + return cache.cast(self._dtype) def forward( self, @@ -130,23 +126,13 @@ def forward( query: paddle.Tensor, key: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]: - cos = self.cos_cache[position_ids].unsqueeze(1) - sin = self.sin_cache[position_ids].unsqueeze(1) - - def rotate_half(x): - """Rotates half the hidden axiss of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return paddle.concat([-x2, x1], axis=-1) # shape is the same as x - - s, h, d = query.shape - query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d]) + import os - s, h, d = key.shape - key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d]) + from paddlenlp_ops import fused_rotary_position_encoding - query = (query * cos) + (rotate_half(query) * sin) - key = (key * cos) + (rotate_half(key) * sin) + # In-place operations that update the query and key tensors. + os.environ["stride_in_no_check_dy2st_diff"] = "1" + fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False) return query, key