From eaafb1ec1960b12e61fe2d9375c769b0ba192757 Mon Sep 17 00:00:00 2001
From: zhangqiu
Date: Tue, 2 Jul 2024 09:01:11 +0000
Subject: [PATCH 1/2] add apply_rotary

---
Review notes (this section is ignored by git am):
rotary_embedding.py now imports Optional and Union from typing. Without that
import the module raised NameError while defining apply_rotary, and the
handler in __init__.py hid the error; the handler now prints the caught
exception. The blob ids on the "index" lines predate these review edits and
should be regenerated with git format-patch.

 deeplink_ext/easyllm_ops/__init__.py         | 14 +++
 deeplink_ext/easyllm_ops/rotary_embedding.py | 91 ++++++++++++++++++++
 2 files changed, 105 insertions(+)
 create mode 100644 deeplink_ext/easyllm_ops/__init__.py
 create mode 100644 deeplink_ext/easyllm_ops/rotary_embedding.py

diff --git a/deeplink_ext/easyllm_ops/__init__.py b/deeplink_ext/easyllm_ops/__init__.py
new file mode 100644
index 0000000..2c19f47
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024, DeepLink.
+
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+try:
+    from .rotary_embedding import apply_rotary
+
+except Exception as e:
+    print(_not_impl.format(op_name="rotary_embedding"))
+    print(f"Rotary Embedding currently does not support fallback! Cause: {e!r}")
+
+__all__ = [
+    "apply_rotary",
+]
diff --git a/deeplink_ext/easyllm_ops/rotary_embedding.py b/deeplink_ext/easyllm_ops/rotary_embedding.py
new file mode 100644
index 0000000..96e14a0
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/rotary_embedding.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024, DeepLink.
+
+from typing import Optional, Union
+
+import torch
+import deeplink_ext.cpp_extensions as ext
+
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """
+    Apply rotary embedding to x through ext.apply_rotary.
+
+    Arguments:
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
+        cos: (seqlen_ro, rotary_dim / 2)
+        sin: (seqlen_ro, rotary_dim / 2)
+        seqlen_offsets: integer or integer tensor of size (batch,).
+            Currently unused: it is not forwarded to the kernel.
+        cu_seqlens: (batch + 1,) or None. Must be None for now, because
+            varlen mode is not supported yet.
+        max_seqlen: int, only read when cu_seqlens is given.
+        interleaved: forwarded unchanged to ext.apply_rotary.
+        inplace: if True, the result is written into x itself.
+        conjugate: forwarded unchanged to ext.apply_rotary.
+    Returns:
+        y: (batch, seqlen, nheads, headdim); x itself if inplace is True.
+    """
+    # Passing cu_seqlens selects varlen mode, but kernel limitations mean
+    # that no device supports varlen mode at the moment.
+    is_varlen = cu_seqlens is not None
+    assert not is_varlen, "varlen mode rotary embedding not supported yet."
+    if not is_varlen:
+        batch, seqlen, nheads, headdim = x.shape
+    else:
+        assert (
+            max_seqlen is not None
+        ), "If cu_seqlens is passed in, then max_seqlen must be passed"
+        total_seqlen, nheads, headdim = x.shape
+        batch_p_1 = cu_seqlens.shape[0]
+        batch = batch_p_1 - 1
+        seqlen = max_seqlen
+    seqlen_ro, rotary_dim = cos.shape
+    assert sin.shape == cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+    assert headdim <= 256, "Only support headdim <= 256"
+    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+    assert (
+        cos.dtype == sin.dtype
+    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+    assert (
+        x.dtype == cos.dtype
+    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+    cos, sin = cos.contiguous(), sin.contiguous()
+    # NOTE: seqlen_offsets is not validated or used yet; the checks below are
+    # kept disabled for reference.
+    # if isinstance(seqlen_offsets, torch.Tensor):
+    #     assert seqlen_offsets.shape == (batch,)
+    #     assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+    #     seqlen_offsets = seqlen_offsets.contiguous()
+    # else:
+    #     assert seqlen_offsets + seqlen <= seqlen_ro
+
+    output = torch.empty_like(x) if not inplace else x
+    if rotary_dim < headdim:
+        if not inplace:
+            # Only the rotary slice is handed to the kernel, so the rest of
+            # the head dimension is copied from x here.
+            output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ext.apply_rotary(
+            output[..., :rotary_dim],
+            x[..., :rotary_dim],
+            cos,
+            sin,
+            conjugate,
+            interleaved,
+        )
+    return output

From 5c8b28e3878339551370092b3a4a0aca687cc2d6 Mon Sep 17 00:00:00 2001
From: zhangqiu
Date: Tue, 2 Jul 2024 09:04:13 +0000
Subject: [PATCH 2/2] fix

---
 deeplink_ext/easyllm_ops/rotary_embedding.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/deeplink_ext/easyllm_ops/rotary_embedding.py b/deeplink_ext/easyllm_ops/rotary_embedding.py
index 96e14a0..d76f262 100644
--- a/deeplink_ext/easyllm_ops/rotary_embedding.py
+++ b/deeplink_ext/easyllm_ops/rotary_embedding.py
@@ -88,4 +88,13 @@ def apply_rotary(
             conjugate,
             interleaved,
         )
+    else:
+        ext.apply_rotary(
+            output,
+            x,
+            cos,
+            sin,
+            conjugate,
+            interleaved,
+        )
     return output