Skip to content

Commit

Permalink
[DeepSeek] Fix some bugs for dsk-v3 (#9874)
Browse files Browse the repository at this point in the history
* use fused rope

* fix import
  • Loading branch information
yuanlehome authored Feb 24, 2025
1 parent 0b26a02 commit 30df8b6
Showing 1 changed file with 13 additions and 27 deletions.
40 changes: 13 additions & 27 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,10 @@ def __init__(
* attn_factor
)

cos_cache, sin_cache = self._compute_cos_sin_cache()
cache = self._compute_cos_sin_cache()

self.cos_cache: paddle.Tensor
self.register_buffer("cos_cache", cos_cache, persistable=True)
self.sin_cache: paddle.Tensor
self.register_buffer("sin_cache", sin_cache, persistable=True)
self.cos_sin_cache: paddle.Tensor
self.register_buffer("cos_sin_cache", cache, persistable=True)

def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
Expand All @@ -116,37 +114,25 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)

freqs = paddle.outer(t, inv_freq)
emb = paddle.concat((freqs, freqs), axis=-1)
cos = emb.cos() * self.mscale
sin = emb.sin() * self.mscale

return cos.cast(self._dtype), sin.cast(self._dtype)
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = paddle.concat((cos, sin), axis=-1)
return cache.cast(self._dtype)

def forward(
self,
position_ids: paddle.Tensor,
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
cos = self.cos_cache[position_ids].unsqueeze(1)
sin = self.sin_cache[position_ids].unsqueeze(1)

def rotate_half(x):
"""Rotates half the hidden axiss of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x

s, h, d = query.shape
query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
import os

s, h, d = key.shape
key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
from paddlenlp_ops import fused_rotary_position_encoding

query = (query * cos) + (rotate_half(query) * sin)
key = (key * cos) + (rotate_half(key) * sin)
# In-place operations that update the query and key tensors.
os.environ["stride_in_no_check_dy2st_diff"] = "1"
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)

return query, key

Expand Down

0 comments on commit 30df8b6

Please sign in to comment.