-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add long sequence strategies #8076
Merged
Merged
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
cd8b643
test
WAI-clear 58dde75
remove test.txt
WAI-clear 58b66a2
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
WAI-clear c5acdc2
add long sequence strategies
WAI-clear c0c950a
long sequence V1
WAI-clear 4d2b599
draft
WAI-clear fbfe654
draft
WAI-clear fa66a5f
draft new
WAI-clear 4e2ffc9
add long sequence stratiges
WAI-clear 6bf1282
add long sequence strategies new
WAI-clear 4aa5917
fix format
WAI-clear c46527e
fix conflict
WAI-clear c202d89
fix format
WAI-clear fe57746
fix error
WAI-clear 1481944
fix error
WAI-clear 56e014a
fix format
WAI-clear d245ad4
fix format
WAI-clear d51de39
fix format
WAI-clear dada723
fix format
WAI-clear 32033f3
close @slow
WAI-clear e1b0ff9
fix test
WAI-clear dcb712f
fix error
WAI-clear a00fcee
modify try_catch
WAI-clear d88b4f4
fix format
WAI-clear 386bca6
fix format
WAI-clear 506b2cf
fix format
WAI-clear 7176f9f
add bloom_alibi
WAI-clear 0c788ad
fix error
WAI-clear ef02a24
add dynamic_to_static
WAI-clear dc8da0a
add dynamic_to_static
WAI-clear File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
49 changes: 49 additions & 0 deletions
49
paddlenlp/transformers/LongSequenceStrategies/AttentionStrategies.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
|
||
import paddle | ||
from paddle import Tensor, nn | ||
|
||
__all__ = ["AttentionWithLinearBias"] | ||
|
||
|
||
class AttentionWithLinearBias(nn.Layer): | ||
def __init__(self, **init_args): | ||
super().__init__() | ||
|
||
def _get_interleave(self, n): | ||
def _get_interleave_power_of_2(n): | ||
start = 2 ** (-(2 ** -(math.log2(n) - 3))) | ||
return [start * start**i for i in range(n)] | ||
|
||
if math.log2(n).is_integer(): | ||
return _get_interleave_power_of_2(n) | ||
else: | ||
closest_power_of_2 = 2 ** math.floor(math.log2(n)) | ||
return ( | ||
_get_interleave_power_of_2(closest_power_of_2) | ||
+ self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] | ||
) | ||
|
||
def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype): | ||
attention_mask = bool_attention_mask.astype("float32") | ||
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1] | ||
slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32") | ||
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze( | ||
axis=[0, 1] | ||
).expand([num_heads, -1, -1]) | ||
alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1]) | ||
return paddle.cast(alibi, dtype) |
122 changes: 122 additions & 0 deletions
122
paddlenlp/transformers/LongSequenceStrategies/EmbeddingStrategies.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
from paddle import nn | ||
|
||
__all__ = [ | ||
"RotaryEmbedding", | ||
"LinearScalingRotaryEmbedding", | ||
"NTKScalingRotaryEmbedding", | ||
"DynamicNTKScalingRotaryEmbedding", | ||
] | ||
|
||
|
||
class RotaryEmbedding(nn.Layer): | ||
def __init__(self, **init_args): | ||
super().__init__() | ||
self.dim = init_args["dim"] | ||
self.max_position_embeddings = init_args["max_position_embeddings"] | ||
self.base = init_args["base"] | ||
self.position_encoding_2d = init_args["position_encoding_2d"] if "position_encoding_2d" in init_args else False | ||
if self.position_encoding_2d: | ||
# [dim / 4]# 2D--Embedding | ||
self.dim = self.dim / 2 | ||
inv_freq = 1.0 / ( | ||
self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim) | ||
) | ||
else: | ||
# [dim / 2] | ||
inv_freq = 1.0 / ( | ||
self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim) | ||
) | ||
self.register_buffer("inv_freq", inv_freq) | ||
self._set_cos_sin_cache(seq_len=self.max_position_embeddings) | ||
|
||
def _set_cos_sin_cache(self, seq_len): | ||
self.max_seq_len_cached = seq_len | ||
# [seq_len] | ||
t = paddle.arange(seq_len, dtype=paddle.float32) | ||
# [seq_len, dim/2] | ||
with paddle.amp.auto_cast(enable=False): | ||
freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq) | ||
# [seq_len, dim] | ||
emb = paddle.concat([freqs, freqs], axis=-1) | ||
self.cos_cached = emb.cos()[:, :] | ||
self.sin_cached = emb.sin()[:, :] | ||
|
||
def forward(self, seq_len=None, ntk_alpha=None): | ||
|
||
return self.cos_cached[:, :], self.sin_cached[:, :] | ||
|
||
|
||
class LinearScalingRotaryEmbedding(RotaryEmbedding): | ||
def __init__(self, **init_args): | ||
self.scaling_factor = init_args["scaling_factor"] | ||
super().__init__(**init_args) | ||
|
||
def _set_cos_sin_cache(self, seq_len): | ||
self.max_seq_len_cached = seq_len | ||
# [seq_len] | ||
t = paddle.arange(seq_len, dtype=paddle.float32) | ||
t = t / self.scaling_factor | ||
# [seq_len, dim/2] | ||
with paddle.amp.auto_cast(enable=False): | ||
freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq) | ||
# [seq_len, dim] | ||
emb = paddle.concat([freqs, freqs], axis=-1) | ||
self.cos_cached = emb.cos()[:, :] | ||
self.sin_cached = emb.sin()[:, :] | ||
|
||
|
||
class NTKScalingRotaryEmbedding(RotaryEmbedding): | ||
"""RotaryEmbedding extended with NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/""" | ||
|
||
def __init__(self, **init_args): | ||
init_args["base"] = init_args["base"] * init_args["scaling_factor"] ** ( | ||
init_args["dim"] / (init_args["dim"] - 2) | ||
) | ||
super().__init__(**init_args) | ||
|
||
|
||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): | ||
"""RotaryEmbedding extended with Dynamic NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/""" | ||
|
||
def __init__(self, **init_args): | ||
self.scaling_factor = init_args["scaling_factor"] | ||
self._seq_len_cached = 0 | ||
super().__init__(**init_args) | ||
|
||
def _scale_cos_sin(self, seq_len, ntk_alpha=None): | ||
# [seq_len] | ||
t = paddle.arange(seq_len, dtype=paddle.float32) | ||
if ntk_alpha is None: | ||
ntk_alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) | ||
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) | ||
|
||
# [seq_len, dim/2] | ||
inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)) | ||
with paddle.amp.auto_cast(enable=False): | ||
freqs = paddle.outer(t.astype(inv_freq.dtype), inv_freq) | ||
# [seq_len, dim] | ||
emb = paddle.concat([freqs, freqs], axis=-1) | ||
self.cos_cached = emb.cos()[:, :] | ||
self.sin_cached = emb.sin()[:, :] | ||
|
||
def forward(self, seq_len=None, ntk_alpha=None): | ||
|
||
if seq_len > self.max_position_embeddings: | ||
self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha) | ||
|
||
return super().forward(seq_len) |
66 changes: 66 additions & 0 deletions
66
paddlenlp/transformers/LongSequenceStrategies/LongSequenceStrategies.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import importlib | ||
|
||
all_strategy_types = ["EmbeddingStrategies", "AttentionStrategies"] | ||
|
||
|
||
class LongSequenceStrategies: | ||
@classmethod | ||
def build_long_sequence_strategy(cls, strategy_type=None, stratety_name=None, **init_args): | ||
""" | ||
|
||
**init_args: head_dim, | ||
max_position_embeddings, | ||
rope_scaling_type, | ||
rope_scaling_factor, | ||
... | ||
|
||
strategy_type: "None" ---------------走原始的build-in模块 | ||
"EmbeddingStrategies"、 | ||
"AttentionStrategies" | ||
... | ||
|
||
stratety_name: "RotaryEmbedding"、 | ||
"LinearScalingRotaryEmbedding"、 | ||
"NTKScalingRotaryEmbedding"、 | ||
"DynamicNTKScalingRotaryEmbedding"、 | ||
"AttentionWithLinearBias" | ||
... | ||
|
||
""" | ||
|
||
""" | ||
paddlenlp.transformers.LongSequenceStrategies.{strategy_type<->import_class)}.{stratety_name<->strategy_class)} | ||
paddlenlp.transformers.LongSequenceStrategies.{EmbeddingStrategies}.{RoPE,...} | ||
paddlenlp.transformers.LongSequenceStrategies.{AttentionStrategies}.{ALiBi,...} | ||
""" | ||
try: | ||
import_class = importlib.import_module(f"paddlenlp.transformers.LongSequenceStrategies.{strategy_type}") | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError( | ||
f"Wrong strategy type {strategy_type}. module only supports the following types: " | ||
+ ", ".join(m for m in all_strategy_types) | ||
) | ||
try: | ||
strategy_class = getattr(import_class, stratety_name) | ||
except: | ||
all_strategy_classes = import_class.__all__ | ||
raise LookupError( | ||
f"module '{import_class.__name__}' only supports the following classes: " | ||
+ ", ".join(m for m in all_strategy_classes) | ||
) | ||
strategy_instance = strategy_class(**init_args) | ||
return strategy_instance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .AttentionStrategies import * | ||
from .EmbeddingStrategies import * | ||
from .LongSequenceStrategies import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件命名,目录命名,小写