add long sequence strategies (#8076)
* test

* remove test.txt

* add long sequence strategies

* long sequence V1

* draft

* draft

* draft new

* add long sequence strategies

* add long sequence strategies new

* fix format

* fix conflict

* fix format

* fix error

* fix error

* fix format

* fix format

* fix format

* fix format

* close @slow

* fix test

* fix error

* modify try_catch

* fix format

* fix format

* fix format

* add bloom_alibi

* fix error

* add dynamic_to_static

* add dynamic_to_static
WAI-clear authored Mar 26, 2024
1 parent a0457d1 commit 6b5099a
Showing 15 changed files with 5,496 additions and 22 deletions.
9 changes: 9 additions & 0 deletions paddlenlp/transformers/bloom/configuration.py
@@ -125,6 +125,10 @@ def __init__(
use_recompute=False,
use_pure_fp16=False,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs,
):

@@ -150,3 +154,8 @@ def __init__(
self.use_recompute = use_recompute
self.use_pure_fp16 = use_pure_fp16
self.use_flash_attention = use_flash_attention

self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
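
The four new BloomConfig fields mirror the arguments that the modeling code passes to LongSequenceStrategies.build_long_sequence_strategy below. A minimal sketch of enabling them when constructing a config; the strategy type/name strings are assumptions for illustration (the diff only shows that AttentionWithLinearBias is exported by the new package), not values prescribed by this PR:

from paddlenlp.transformers import BloomConfig

# Sketch only: the strategy type/name strings below are assumed, not taken from this diff.
config = BloomConfig(
    use_long_sequence_strategies=True,
    long_sequence_strategy_type="attention_strategies",    # assumed submodule key
    long_sequence_strategy_name="AttentionWithLinearBias", # alibi strategy class added in this PR
    long_sequence_init_args={},                             # falls back to {} when None
)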
24 changes: 21 additions & 3 deletions paddlenlp/transformers/bloom/modeling.py
100644 → 100755
@@ -26,6 +26,7 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -946,10 +947,27 @@ def forward(
attention_mask = paddle.cast(attention_mask, "bool")
if len(attention_mask.shape) > 2:
_attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,
**self.config.long_sequence_init_args,
)
alibi = alibi_layer(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
alibi = paddle.squeeze(alibi)
else:
alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
else:
alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,
**self.config.long_sequence_init_args,
)
alibi = alibi_layer(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
alibi = paddle.squeeze(alibi)
else:
alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
if self.config.tensor_parallel_degree > 1:
block_size = self.config.n_head // self.config.tensor_parallel_degree
alibi = alibi[
9 changes: 8 additions & 1 deletion paddlenlp/transformers/chatglm/configuration.py
@@ -21,7 +21,6 @@
"CHATGLM_PRETRAINED_RESOURCE_FILES_MAP",
]


CHATGLM_PRETRAINED_RESOURCE_FILES_MAP = {
"model_state": {
"THUDM/chatglm-6b": "https://paddlenlp.bj.bcebos.com/models/community/THUDM/chatglm-6b/model_state.pdparams",
@@ -104,6 +103,10 @@ def __init__(
activation="gelu",
num_image_tokens=0,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -129,3 +132,7 @@ def __init__(
self.activation = activation
self.num_image_tokens = num_image_tokens
self.use_flash_attention = use_flash_attention
self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
37 changes: 28 additions & 9 deletions paddlenlp/transformers/chatglm/modeling.py
100644 → 100755
@@ -27,6 +27,8 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import map_structure

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

from ...utils.env import CONFIG_NAME
from ...utils.log import logger
from .. import PretrainedModel, register_base_model
@@ -442,12 +444,21 @@ def __init__(self, config: ChatGLMConfig):
# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False
self.num_attention_heads = config.num_attention_heads
self.rotary_embeddings = RotaryEmbeddings(
self.hidden_size // (self.num_attention_heads * 2)
if self.position_encoding_2d
else self.hidden_size // self.num_attention_heads,
base=10000.0,
)

if config.use_long_sequence_strategies:
self.rotary_embeddings = LongSequenceStrategies.build_long_sequence_strategy(
config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)

else:
self.rotary_embeddings = RotaryEmbeddings(
self.hidden_size // (self.num_attention_heads * 2)
if self.position_encoding_2d
else self.hidden_size // self.num_attention_heads,
base=10000.0,
)
# self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)

if self.config.tensor_parallel_degree > 1:
@@ -530,7 +541,6 @@ def forward(
cache: Optional[Tensor] = None,
use_cache: bool = False,
):

if input_ids is not None and inputs_embeds is not None:
input_ids = None
logger.warning("Specify both input_ids and inputs_embeds at the same time, will use inputs_embeds")
@@ -544,8 +554,17 @@
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
inputs_embeds = inputs_embeds.transpose([1, 0, 2])

rotary_embeds = self.rotary_embeddings(position_ids)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_embeddings(seq_len=seq_length)
block_position_ids = position_ids[:, 1, :].transpose([1, 0])
position_ids = position_ids[:, 0, :].transpose([1, 0])
block_rotary_embeds = paddle.stack(
[cos[block_position_ids].unsqueeze(2), sin[block_position_ids].unsqueeze(2)]
)
position_rotary_embeds = paddle.stack([cos[position_ids].unsqueeze(2), sin[position_ids].unsqueeze(2)])
rotary_embeds = paddle.stack([position_rotary_embeds, block_rotary_embeds], axis=0)
else:
rotary_embeds = self.rotary_embeddings(position_ids)

if cache is None:
if self.config.pre_seq_len is not None:
8 changes: 8 additions & 0 deletions paddlenlp/transformers/chatglm_v2/configuration.py
@@ -56,6 +56,10 @@ def __init__(
eos_token_id=2,
pad_token_id=0,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -81,3 +85,7 @@ def __init__(
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.use_flash_attention = use_flash_attention
self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
22 changes: 19 additions & 3 deletions paddlenlp/transformers/chatglm_v2/modeling.py
@@ -21,6 +21,8 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import map_structure

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

from ...utils.converter import StateDictNameMapping, init_name_mappings
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
@@ -650,7 +652,15 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
rotary_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
if config.use_long_sequence_strategies:
self.config = config
self.rotary_pos_emb = LongSequenceStrategies.build_long_sequence_strategy(
config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)
else:
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
self.encoder = GLMTransformer(config)
self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False)

@@ -677,7 +687,6 @@ def forward(
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
@@ -686,7 +695,14 @@
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_pos_emb(seq_len=self.max_sequence_length)
cos, cos = paddle.chunk(cos, 2, axis=-1)
sin, sin = paddle.chunk(sin, 2, axis=-1)
rotary_pos_emb = paddle.stack([cos, sin], axis=-1)
else:
rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length)

if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
9 changes: 9 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -168,6 +168,10 @@ def __init__(
alibi=False,
rope_scaling_factor=1.0,
rope_scaling_type=None,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -208,6 +212,11 @@ def __init__(
self.rope_scaling_factor = rope_scaling_factor
self.rope_scaling_type = rope_scaling_type

self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
33 changes: 30 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
100644 → 100755
@@ -57,6 +57,7 @@ def swiglu(x, y=None):
StateDictNameMapping,
init_name_mappings,
)
from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -763,7 +764,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

if config.rope:
self._init_rope()
if config.use_long_sequence_strategies:
self.rotary_emb = LongSequenceStrategies.build_long_sequence_strategy(
config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)
else:
self._init_rope()

self.reshard_layer = None
if config.sep_parallel_degree > 1:
@@ -972,7 +980,17 @@ def forward(
use_neox_rotary_style=False,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
cos, sin = (
cos.cast(value_states.dtype) if cos.dtype != value_states.dtype else cos,
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# [bs, seq_len, num_head, head_dim]
@@ -1325,6 +1343,7 @@ def __init__(self, config: LlamaConfig):
self.sequence_parallel = config.sequence_parallel
self.recompute_granularity = config.recompute_granularity
self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
self.config = config

# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False
@@ -1477,7 +1496,15 @@ def forward(
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if self.config.alibi:
alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,
**self.config.long_sequence_init_args,
)
alibi = alibi_layer(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
else:
alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
if self.config.tensor_parallel_degree > 1:
block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree
alibi = alibi[
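
The Llama changes route both paths through the same switch: when config.rope is set, self.rotary_emb comes from the factory instead of _init_rope(), and when config.alibi is set, the alibi tensor comes from the strategy layer instead of build_alibi_tensor. A minimal sketch of turning this on for a rope-style strategy; the strategy name and init args below are hypothetical placeholders, not values shown in this diff:

from paddlenlp.transformers import LlamaConfig

# Sketch only: "embedding_strategies", "RotaryEmbedding" and the init args are assumed names.
config = LlamaConfig(
    use_long_sequence_strategies=True,
    long_sequence_strategy_type="embedding_strategies",  # assumed submodule key
    long_sequence_strategy_name="RotaryEmbedding",        # hypothetical strategy class
    long_sequence_init_args={"max_position_embeddings": 4096},  # hypothetical init args
)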
18 changes: 18 additions & 0 deletions paddlenlp/transformers/long_sequence_strategies/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .attention_strategies import *
from .embedding_strategies import *
from .long_sequence_strategies import *
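
Every model above resolves its strategy through the single factory exposed here, LongSequenceStrategies.build_long_sequence_strategy(strategy_type, strategy_name, **init_args). A minimal usage sketch, assuming the type names one of the submodules imported above and the name a class it exports (only AttentionWithLinearBias is visible in this diff):

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

# Sketch only: "attention_strategies" as the type key is an assumption about how the
# factory resolves strategy_type; AttentionWithLinearBias is the class added below.
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
    "attention_strategies",     # strategy_type (assumed submodule key)
    "AttentionWithLinearBias",  # strategy_name (class added in this PR)
)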
51 changes: 51 additions & 0 deletions paddlenlp/transformers/long_sequence_strategies/attention_strategies.py
@@ -0,0 +1,51 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle
from paddle import Tensor, nn

__all__ = ["AttentionWithLinearBias"]


class AttentionWithLinearBias(nn.Layer):
def __init__(self, **init_args):
super().__init__()

def _get_interleave(self, n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return np.array([start * start**i for i in range(n)]).astype(np.float32)

if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)

def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype):
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32")
with paddle.amp.auto_cast(enable=False):
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(
axis=[0, 1]
).expand([num_heads, -1, -1])
alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1])
return paddle.cast(alibi, dtype)
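
AttentionWithLinearBias packages ALiBi slope construction as a strategy layer: the slopes form a geometric sequence with ratio 2^(-8/n) for n heads (interleaved when n is not a power of two), and the forward pass broadcasts them over position indices to yield a bias of shape [batch_size, num_heads, 1, seq_length]. A small usage sketch with illustrative shapes:

import paddle

from paddlenlp.transformers.long_sequence_strategies import AttentionWithLinearBias

# Illustrative shapes: batch of 2, sequence length 128, 32 attention heads.
mask = paddle.ones([2, 128], dtype="bool")  # [batch_size, seq_length]; only the shape is used
alibi_layer = AttentionWithLinearBias()
alibi = alibi_layer(mask, 32, dtype="float16")
print(alibi.shape)  # expected: [2, 32, 1, 128]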