add long sequence strategies #8076
Changes from 22 commits
@@ -1,32 +1,32 @@
{
    "model_name_or_path": "facebook/llama-7b",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/llama_sft_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-05,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 4,
    "pipeline_parallel_degree": 1,
    "intokens": true,
    "zero_padding": false,
    "use_flash_attention": false
}
"model_name_or_path": "facebook/llama-7b", | ||
"dataset_name_or_path": "./data", | ||
"output_dir": "./checkpoints/llama_sft_ckpts", | ||
"per_device_train_batch_size": 4, | ||
"gradient_accumulation_steps": 4, | ||
"per_device_eval_batch_size": 8, | ||
"eval_accumulation_steps":16, | ||
"num_train_epochs": 3, | ||
"learning_rate": 3e-05, | ||
"warmup_steps": 30, | ||
"logging_steps": 1, | ||
"evaluation_strategy": "epoch", | ||
"save_strategy": "epoch", | ||
"src_length": 1024, | ||
"max_length": 2048, | ||
"fp16": true, | ||
"fp16_opt_level": "O2", | ||
"do_train": true, | ||
"do_eval": true, | ||
"disable_tqdm": true, | ||
"load_best_model_at_end": true, | ||
"eval_with_do_generation": false, | ||
"metric_for_best_model": "accuracy", | ||
"recompute": true, | ||
"save_total_limit": 1, | ||
"tensor_parallel_degree": 4, | ||
"pipeline_parallel_degree": 1, | ||
"intokens": true, | ||
"zero_padding": false, | ||
"use_flash_attention": false | ||
Review comment: this JSON also should not need to be modified.
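For context, a minimal sketch of how this argument file could be inspected with the standard library alone; the field names come straight from the JSON above, but the file path and the idea of loading it as a plain dict (rather than through the trainer's own argument parser) are assumptions for illustration.

import json

# Hypothetical path; the PR does not name the config file.
with open("sft_argument.json", "r", encoding="utf-8") as f:
    args = json.load(f)

# The long-sequence-related knobs sit next to the usual trainer arguments.
print(args["max_length"])           # 2048: total sequence length budget
print(args["src_length"])           # 1024: source/prompt length budget
print(args["use_flash_attention"])  # false in this config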
@@ -0,0 +1,57 @@
Review comment: file and directory names should be lowercase.
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
from paddle import Tensor, nn

__all__ = ["AttentionWithLinearBias"]


class AttentionWithLinearBias(nn.Layer):
    """
    init_args: bool_attention_mask, num_heads, dtype, tensor_parallel_degree
    """

    def __init__(self, **init_args):
        """
        **init_args: ...
        """
        super().__init__()

    def _get_interleave(self, n):
        # Compute the ALiBi slope for each of the n attention heads.
        def _get_interleave_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            # Review comment: ratio and start are equal here; can they be merged?
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return _get_interleave_power_of_2(n)
        else:
            # For non-power-of-two head counts, extend the closest power-of-two slopes
            # with every other slope from the doubled set.
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                _get_interleave_power_of_2(closest_power_of_2)
                + self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1):
        # Review comment: what is the purpose of passing in tensor_parallel_degree?
        attention_mask = bool_attention_mask.astype("float32")
        batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
        slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32")
        # Per-head slope times key position index: shape [num_heads, 1, seq_length].
        alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(
            axis=[0, 1]
        ).expand([num_heads, -1, -1])
        # Broadcast to [batch_size, num_heads, 1, seq_length] and cast to the requested dtype.
        alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1])
        return paddle.cast(alibi, dtype)
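To make the expected call pattern concrete, a minimal usage sketch, assuming the layer is used standalone; the argument names and the output shape follow from forward() above, while the batch size, head count, and sequence length are arbitrary.

import paddle

# Standalone illustration only; in practice the layer is built through the strategy factory.
alibi_layer = AttentionWithLinearBias()
bool_mask = paddle.ones([2, 16], dtype="bool")   # [batch_size, seq_length]
bias = alibi_layer(bool_mask, num_heads=8, dtype=paddle.float16)
print(bias.shape)                                # [2, 8, 1, 16]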
@@ -0,0 +1,122 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn

__all__ = [
    "RotaryEmbedding",
    "LinearScalingRotaryEmbedding",
    "NTKScalingRotaryEmbedding",
    "DynamicNTKScalingRotaryEmbedding",
]


class RotaryEmbedding(nn.Layer):
    def __init__(self, **init_args):
        super().__init__()
        self.dim = init_args["dim"]
        self.max_position_embeddings = init_args["max_position_embeddings"]
        self.base = init_args["base"]
        self.position_encoding_2d = init_args["position_encoding_2d"] if "position_encoding_2d" in init_args else False
        if self.position_encoding_2d:
            # [dim / 4]  # 2D position embedding
            self.dim = self.dim / 2
            inv_freq = 1.0 / (
                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)
            )
        else:
            # [dim / 2]
            inv_freq = 1.0 / (
                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)
            )
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(seq_len=self.max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        # [seq_len, dim/2]
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]

    def forward(self, seq_len=None, ntk_alpha=None):
        return self.cos_cached[:seq_len, :], self.sin_cached[:seq_len, :]


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    def __init__(self, **init_args):
        self.scaling_factor = init_args["scaling_factor"]
        super().__init__(**init_args)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        # Position interpolation: divide the position indices by scaling_factor.
        t = t / self.scaling_factor
        # [seq_len, dim/2]
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]


class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/"""

    def __init__(self, **init_args):
        # NTK scaling rescales the rotary base instead of the position indices.
        init_args["base"] = init_args["base"] * init_args["scaling_factor"] ** (
            init_args["dim"] / (init_args["dim"] - 2)
        )
        super().__init__(**init_args)


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/"""

    def __init__(self, **init_args):
        self.scaling_factor = init_args["scaling_factor"]
        self._seq_len_cached = 0
        super().__init__(**init_args)

    def _scale_cos_sin(self, seq_len, ntk_alpha=None):
        # [seq_len]
        t = paddle.arange(seq_len, dtype=paddle.float32)
        if ntk_alpha is None:
            # Derive the NTK alpha from the current sequence length.
            ntk_alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))

        # [seq_len, dim/2]
        inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim))
        with paddle.amp.auto_cast(enable=False):
            freqs = paddle.outer(t.astype(inv_freq.dtype), inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        self.cos_cached = emb.cos()[:, :]
        self.sin_cached = emb.sin()[:, :]

    def forward(self, seq_len=None, ntk_alpha=None):
        # Rebuild the cos/sin cache only when the requested length exceeds the trained window.
        if seq_len > self.max_position_embeddings:
            self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha)

        return super().forward(seq_len)
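A short sketch of how these embeddings are meant to be queried. The constructor keys (dim, max_position_embeddings, base, scaling_factor) come from the classes above; the concrete values are arbitrary, and the commented import path is an assumption based on the package layout referenced elsewhere in this PR.

# Assumed import path, illustration only:
# from paddlenlp.transformers.LongSequenceStrategies.EmbeddingStrategies import DynamicNTKScalingRotaryEmbedding

rope = DynamicNTKScalingRotaryEmbedding(dim=128, max_position_embeddings=2048, base=10000, scaling_factor=2.0)
cos, sin = rope(seq_len=1024)            # within the trained window: cached tables, [1024, 128] each
cos_long, sin_long = rope(seq_len=4096)  # beyond it: cache rebuilt with a dynamically scaled base, [4096, 128]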
@@ -0,0 +1,61 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib


class LongSequenceStrategies:
    @classmethod
    def build_long_sequence_strategy(cls, strategy_type=None, stratety_name=None, **init_args):
        """
        **init_args: head_dim,
                     max_position_embeddings,
                     rope_scaling_type,
                     rope_scaling_factor,
                     ...

        strategy_type: "None" -> use the original built-in modules
                       "EmbeddingStrategies",
                       "AttentionStrategies"
                       ...

        stratety_name: "RotaryEmbedding",
                       "LinearScalingRotaryEmbedding",
                       "NTKScalingRotaryEmbedding",
                       "DynamicNTKScalingRotaryEmbedding",
                       "AttentionWithLinearBias"
                       ...
        """

        """
        paddlenlp.transformers.LongSequenceStrategies.{strategy_type <-> import_class}.{stratety_name <-> strategy_class}
        paddlenlp.transformers.LongSequenceStrategies.{EmbeddingStrategies}.{RoPE,...}
        paddlenlp.transformers.LongSequenceStrategies.{AttentionStrategies}.{ALiBi,...}
        """
        try:
            import_class = importlib.import_module(f"paddlenlp.transformers.LongSequenceStrategies.{strategy_type}")
        except ValueError:
            # Review comment: shouldn't this be ModuleNotFoundError?
            raise ValueError(f"Wrong strategy type {strategy_type}.")
        try:
            strategy_class = getattr(import_class, stratety_name)
            strategy_instance = strategy_class(**init_args)
            return strategy_instance
        except AttributeError:
            # Review comment: if the failure comes from strategy_class(**init_args), is the error really an AttributeError?
            all_strategy_classes = import_class.__all__
            raise AttributeError(
                f"module '{import_class.__name__}' only supports the following classes: "
                + ", ".join(m for m in all_strategy_classes)
            )
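To show how the factory ties the pieces together, a hedged usage sketch: the strategy_type and stratety_name values come from the docstring above, the keyword arguments match the DynamicNTKScalingRotaryEmbedding constructor, and the import path is an assumption based on the package __init__ shown below.

from paddlenlp.transformers.LongSequenceStrategies import LongSequenceStrategies  # assumed import path

# Build a dynamic-NTK rotary embedding through the factory (values are illustrative).
rope = LongSequenceStrategies.build_long_sequence_strategy(
    strategy_type="EmbeddingStrategies",
    stratety_name="DynamicNTKScalingRotaryEmbedding",
    dim=128,
    max_position_embeddings=2048,
    base=10000,
    scaling_factor=4.0,
)
cos, sin = rope(seq_len=8192)  # cos/sin tables are rescaled on the fly past the trained window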
@@ -0,0 +1,18 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .AttentionStrategies import *
from .EmbeddingStrategies import *
from .LongSequenceStrategies import *
Review comment: this file should not need to be modified, right?