add long sequence strategies #8076

Merged · 30 commits · Mar 26, 2024
49 changes: 49 additions & 0 deletions paddlenlp/transformers/LongSequenceStrategies/AttentionStrategies.py
@@ -0,0 +1,49 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Review comment (Member): File names and directory names should be lowercase.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
from paddle import Tensor, nn

__all__ = ["AttentionWithLinearBias"]


class AttentionWithLinearBias(nn.Layer):
    def __init__(self, **init_args):
        super().__init__()

    def _get_interleave(self, n):
        def _get_interleave_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [start * start**i for i in range(n)]

        if math.log2(n).is_integer():
            return _get_interleave_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                _get_interleave_power_of_2(closest_power_of_2)
                + self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype):
        attention_mask = bool_attention_mask.astype("float32")
        batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
        slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32")
        alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(
            axis=[0, 1]
        ).expand([num_heads, -1, -1])
        alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1])
        return paddle.cast(alibi, dtype)
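Note: `_get_interleave` produces the per-head ALiBi slopes (a geometric sequence, interleaved when the head count is not a power of two), and `forward` multiplies them by position indices to build a `[batch, num_heads, 1, seq_len]` bias that is added to attention scores. A minimal usage sketch, assuming the package path added in this PR (the review above suggests it may later be renamed to lowercase); shapes are illustrative:

```python
import math

import paddle

from paddlenlp.transformers.LongSequenceStrategies import AttentionWithLinearBias

batch_size, num_heads, seq_len = 2, 8, 16  # illustrative shapes

alibi_layer = AttentionWithLinearBias()
bool_mask = paddle.ones([batch_size, seq_len], dtype="bool")

# Bias tensor to be added to the attention scores before softmax.
alibi = alibi_layer(bool_mask, num_heads, dtype=paddle.float16)
print(alibi.shape)  # [2, 8, 1, 16]

# For a power-of-two head count, slopes[i] == 2 ** (-8 * (i + 1) / num_heads).
slopes = alibi_layer._get_interleave(num_heads)
assert math.isclose(slopes[0], 2 ** (-8 / num_heads))
```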
122 changes: 122 additions & 0 deletions paddlenlp/transformers/LongSequenceStrategies/EmbeddingStrategies.py
@@ -0,0 +1,122 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn

__all__ = [
"RotaryEmbedding",
"LinearScalingRotaryEmbedding",
"NTKScalingRotaryEmbedding",
"DynamicNTKScalingRotaryEmbedding",
]


class RotaryEmbedding(nn.Layer):
    def __init__(self, **init_args):
        super().__init__()
        self.dim = init_args["dim"]
        self.max_position_embeddings = init_args["max_position_embeddings"]
        self.base = init_args["base"]
        self.position_encoding_2d = init_args["position_encoding_2d"] if "position_encoding_2d" in init_args else False
        if self.position_encoding_2d:
            # [dim / 4]  2D embedding
            self.dim = self.dim / 2
            inv_freq = 1.0 / (
                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)
            )
        else:
            # [dim / 2]
            inv_freq = 1.0 / (
                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)
            )
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(seq_len=self.max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        # [seq_len, dim/2]
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]

    def forward(self, seq_len=None, ntk_alpha=None):
        return self.cos_cached[:, :], self.sin_cached[:, :]


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    def __init__(self, **init_args):
        self.scaling_factor = init_args["scaling_factor"]
        super().__init__(**init_args)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        t = t / self.scaling_factor
        # [seq_len, dim/2]
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]


class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/"""

    def __init__(self, **init_args):
        init_args["base"] = init_args["base"] * init_args["scaling_factor"] ** (
            init_args["dim"] / (init_args["dim"] - 2)
        )
        super().__init__(**init_args)


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/"""

    def __init__(self, **init_args):
        self.scaling_factor = init_args["scaling_factor"]
        self._seq_len_cached = 0
        super().__init__(**init_args)

    def _scale_cos_sin(self, seq_len, ntk_alpha=None):
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        if ntk_alpha is None:
            ntk_alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))

        # [seq_len, dim/2]
        inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim))
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(inv_freq.dtype), inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]

    def forward(self, seq_len=None, ntk_alpha=None):
        if seq_len > self.max_position_embeddings:
            self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha)

        return super().forward(seq_len)
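Note: `LinearScalingRotaryEmbedding` stretches position indices by `1/scaling_factor`, `NTKScalingRotaryEmbedding` enlarges the rotary base once at construction, and `DynamicNTKScalingRotaryEmbedding` rebuilds the cos/sin cache with a sequence-length-dependent base only when the requested length exceeds `max_position_embeddings`. A minimal sketch with assumed hyperparameters (not taken from the PR's tests):

```python
import paddle

from paddlenlp.transformers.LongSequenceStrategies import (
    DynamicNTKScalingRotaryEmbedding,
    LinearScalingRotaryEmbedding,
)

# Assumed hyperparameters for illustration.
init_args = {"dim": 64, "max_position_embeddings": 2048, "base": 10000, "scaling_factor": 4}

linear_rope = LinearScalingRotaryEmbedding(**init_args)
cos, sin = linear_rope()               # cache built for 2048 positions, indices scaled by 1/4
print(cos.shape)                       # [2048, 64]

dynamic_rope = DynamicNTKScalingRotaryEmbedding(**init_args)
cos, sin = dynamic_rope(seq_len=4096)  # exceeds 2048, so the cache is rebuilt with a larger base
print(cos.shape)                       # [4096, 64]
```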
66 changes: 66 additions & 0 deletions paddlenlp/transformers/LongSequenceStrategies/LongSequenceStrategies.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib

all_strategy_types = ["EmbeddingStrategies", "AttentionStrategies"]


class LongSequenceStrategies:
    @classmethod
    def build_long_sequence_strategy(cls, strategy_type=None, strategy_name=None, **init_args):
        """

        **init_args:   head_dim,
                       max_position_embeddings,
                       rope_scaling_type,
                       rope_scaling_factor,
                       ...

        strategy_type: "None" --------------- use the original built-in module
                       "EmbeddingStrategies",
                       "AttentionStrategies"
                       ...

        strategy_name: "RotaryEmbedding",
                       "LinearScalingRotaryEmbedding",
                       "NTKScalingRotaryEmbedding",
                       "DynamicNTKScalingRotaryEmbedding",
                       "AttentionWithLinearBias"
                       ...

        """

        """
        paddlenlp.transformers.LongSequenceStrategies.{strategy_type<->import_class}.{strategy_name<->strategy_class}
        paddlenlp.transformers.LongSequenceStrategies.{EmbeddingStrategies}.{RoPE,...}
        paddlenlp.transformers.LongSequenceStrategies.{AttentionStrategies}.{ALiBi,...}
        """
        try:
            import_class = importlib.import_module(f"paddlenlp.transformers.LongSequenceStrategies.{strategy_type}")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(

                f"Wrong strategy type {strategy_type}. module only supports the following types: "
                + ", ".join(m for m in all_strategy_types)
            )
        try:
            strategy_class = getattr(import_class, strategy_name)
        except AttributeError:
            all_strategy_classes = import_class.__all__
            raise LookupError(

                f"module '{import_class.__name__}' only supports the following classes: "
                + ", ".join(m for m in all_strategy_classes)
            )
        strategy_instance = strategy_class(**init_args)
        return strategy_instance
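Note: the factory simply resolves the strategy type with `importlib.import_module`, looks up the strategy name with `getattr`, and instantiates the class with `**init_args`. A hedged usage sketch (argument values are illustrative; in the model code below they come from the config):

```python
import paddle

from paddlenlp.transformers.LongSequenceStrategies import LongSequenceStrategies

# Rotary embedding strategy: init_args are forwarded to DynamicNTKScalingRotaryEmbedding.
rope = LongSequenceStrategies.build_long_sequence_strategy(
    "EmbeddingStrategies",
    "DynamicNTKScalingRotaryEmbedding",
    dim=64,
    max_position_embeddings=2048,
    base=10000,
    scaling_factor=4,
)
cos, sin = rope(seq_len=4096)

# Attention bias strategy: AttentionWithLinearBias takes no init_args.
alibi = LongSequenceStrategies.build_long_sequence_strategy(
    "AttentionStrategies",
    "AttentionWithLinearBias",
)
bias = alibi(paddle.ones([1, 4096], dtype="bool"), num_heads=32, dtype=paddle.float32)
```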
18 changes: 18 additions & 0 deletions paddlenlp/transformers/LongSequenceStrategies/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .AttentionStrategies import *
from .EmbeddingStrategies import *
from .LongSequenceStrategies import *
9 changes: 9 additions & 0 deletions paddlenlp/transformers/bloom/configuration.py
@@ -125,6 +125,10 @@ def __init__(
        use_recompute=False,
        use_pure_fp16=False,
        use_flash_attention=False,
        long_sequence_strategy_type=None,
        long_sequence_strategy_name=None,
        long_sequence_init_args=None,
        use_long_sequence_strategies=False,
        **kwargs,
    ):

@@ -150,3 +154,8 @@ def __init__(
        self.use_recompute = use_recompute
        self.use_pure_fp16 = use_pure_fp16
        self.use_flash_attention = use_flash_attention

        self.long_sequence_strategy_type = long_sequence_strategy_type
        self.long_sequence_strategy_name = long_sequence_strategy_name
        self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
        self.use_long_sequence_strategies = use_long_sequence_strategies
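Note: these four config fields let a model opt into a pluggable strategy without further changes to its modeling code. A sketch of the config wiring for Bloom (which uses ALiBi); the pattern is an assumption for illustration, not a tested recipe from this PR:

```python
from paddlenlp.transformers import BloomConfig

config = BloomConfig()  # defaults; only the new fields from this PR are set below
config.use_long_sequence_strategies = True
config.long_sequence_strategy_type = "AttentionStrategies"
config.long_sequence_strategy_name = "AttentionWithLinearBias"
config.long_sequence_init_args = {}
# With this config, BloomModel builds its ALiBi bias through
# LongSequenceStrategies.build_long_sequence_strategy instead of build_alibi_tensor.
```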
24 changes: 21 additions & 3 deletions paddlenlp/transformers/bloom/modeling.py
@@ -26,6 +26,7 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.LongSequenceStrategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
@@ -944,10 +945,27 @@
        attention_mask = paddle.cast(attention_mask, "bool")
        if len(attention_mask.shape) > 2:
            _attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
            alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
            if self.config.use_long_sequence_strategies:
                alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
                    self.config.long_sequence_strategy_type,
                    self.config.long_sequence_strategy_name,
                    **self.config.long_sequence_init_args,
                )
                alibi = alibi_layer(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
                alibi = paddle.squeeze(alibi)
            else:
                alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        else:
            alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

            if self.config.use_long_sequence_strategies:
                alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
                    self.config.long_sequence_strategy_type,
                    self.config.long_sequence_strategy_name,
                    **self.config.long_sequence_init_args,
                )
                alibi = alibi_layer(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
                alibi = paddle.squeeze(alibi)
            else:
                alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
        if self.config.tensor_parallel_degree > 1:
            block_size = self.config.n_head // self.config.tensor_parallel_degree
            alibi = alibi[
9 changes: 8 additions & 1 deletion paddlenlp/transformers/chatglm/configuration.py
@@ -21,7 +21,6 @@
    "CHATGLM_PRETRAINED_RESOURCE_FILES_MAP",
]


CHATGLM_PRETRAINED_RESOURCE_FILES_MAP = {
    "model_state": {
        "THUDM/chatglm-6b": "https://paddlenlp.bj.bcebos.com/models/community/THUDM/chatglm-6b/model_state.pdparams",
@@ -104,6 +103,10 @@ def __init__(
        activation="gelu",
        num_image_tokens=0,
        use_flash_attention=False,
        long_sequence_strategy_type=None,
        long_sequence_strategy_name=None,
        long_sequence_init_args=None,
        use_long_sequence_strategies=False,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -129,3 +132,7 @@ def __init__(
        self.activation = activation
        self.num_image_tokens = num_image_tokens
        self.use_flash_attention = use_flash_attention
        self.long_sequence_strategy_type = long_sequence_strategy_type
        self.long_sequence_strategy_name = long_sequence_strategy_name
        self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
        self.use_long_sequence_strategies = use_long_sequence_strategies
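Note: ChatGLM uses rotary position embeddings, so an embedding-side strategy is the natural fit here. A hedged sketch of the config wiring; the `long_sequence_init_args` keys and values are assumptions for illustration, and `dim` must match the model's per-head rotary dimension:

```python
from paddlenlp.transformers import ChatGLMConfig

config = ChatGLMConfig()  # defaults; only the new fields from this PR are set below
config.use_long_sequence_strategies = True
config.long_sequence_strategy_type = "EmbeddingStrategies"
config.long_sequence_strategy_name = "DynamicNTKScalingRotaryEmbedding"
config.long_sequence_init_args = {
    "dim": 64,                        # assumed rotary dimension per head
    "max_position_embeddings": 2048,
    "base": 10000,
    "scaling_factor": 4,
}
```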