diff --git a/paddlenlp/transformers/bloom/configuration.py b/paddlenlp/transformers/bloom/configuration.py index 03f3f8e8d2c3..5b430f3cad51 100644 --- a/paddlenlp/transformers/bloom/configuration.py +++ b/paddlenlp/transformers/bloom/configuration.py @@ -125,6 +125,10 @@ def __init__( use_recompute=False, use_pure_fp16=False, use_flash_attention=False, + long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, **kwargs, ): @@ -150,3 +154,8 @@ def __init__( self.use_recompute = use_recompute self.use_pure_fp16 = use_pure_fp16 self.use_flash_attention = use_flash_attention + + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies diff --git a/paddlenlp/transformers/bloom/modeling.py b/paddlenlp/transformers/bloom/modeling.py old mode 100644 new mode 100755 index 1f7ad2d299f9..d729225b7ccc --- a/paddlenlp/transformers/bloom/modeling.py +++ b/paddlenlp/transformers/bloom/modeling.py @@ -26,6 +26,7 @@ from paddle.distributed import fleet from paddle.distributed.fleet.utils import recompute +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -944,10 +945,27 @@ def forward( attention_mask = paddle.cast(attention_mask, "bool") if len(attention_mask.shape) > 2: _attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool") - alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype) + if self.config.use_long_sequence_strategies: + alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( + self.config.long_sequence_strategy_type, + self.config.long_sequence_strategy_name, + **self.config.long_sequence_init_args, + ) + alibi = alibi_layer(_attention_mask, self.config.n_head, dtype=hidden_states.dtype) + alibi = paddle.squeeze(alibi) + else: + alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype) else: - alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) - + if self.config.use_long_sequence_strategies: + alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( + self.config.long_sequence_strategy_type, + self.config.long_sequence_strategy_name, + **self.config.long_sequence_init_args, + ) + alibi = alibi_layer(attention_mask, self.config.n_head, dtype=hidden_states.dtype) + alibi = paddle.squeeze(alibi) + else: + alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) if self.config.tensor_parallel_degree > 1: block_size = self.config.n_head // self.config.tensor_parallel_degree alibi = alibi[ diff --git a/paddlenlp/transformers/chatglm/configuration.py b/paddlenlp/transformers/chatglm/configuration.py index 2b2af745a963..e7d86b1274a2 100644 --- a/paddlenlp/transformers/chatglm/configuration.py +++ b/paddlenlp/transformers/chatglm/configuration.py @@ -21,7 +21,6 @@ "CHATGLM_PRETRAINED_RESOURCE_FILES_MAP", ] - CHATGLM_PRETRAINED_RESOURCE_FILES_MAP = { "model_state": { "THUDM/chatglm-6b": "https://paddlenlp.bj.bcebos.com/models/community/THUDM/chatglm-6b/model_state.pdparams", @@ -104,6 +103,10 @@ def __init__( activation="gelu", num_image_tokens=0, 
use_flash_attention=False, + long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -129,3 +132,7 @@ def __init__( self.activation = activation self.num_image_tokens = num_image_tokens self.use_flash_attention = use_flash_attention + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies diff --git a/paddlenlp/transformers/chatglm/modeling.py b/paddlenlp/transformers/chatglm/modeling.py old mode 100644 new mode 100755 index dd103d070642..879ee71b2ab0 --- a/paddlenlp/transformers/chatglm/modeling.py +++ b/paddlenlp/transformers/chatglm/modeling.py @@ -27,6 +27,8 @@ from paddle.distributed.fleet.utils import recompute from paddle.utils import map_structure +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies + from ...utils.env import CONFIG_NAME from ...utils.log import logger from .. import PretrainedModel, register_base_model @@ -442,12 +444,21 @@ def __init__(self, config: ChatGLMConfig): # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False self.num_attention_heads = config.num_attention_heads - self.rotary_embeddings = RotaryEmbeddings( - self.hidden_size // (self.num_attention_heads * 2) - if self.position_encoding_2d - else self.hidden_size // self.num_attention_heads, - base=10000.0, - ) + + if config.use_long_sequence_strategies: + self.rotary_embeddings = LongSequenceStrategies.build_long_sequence_strategy( + config.long_sequence_strategy_type, + config.long_sequence_strategy_name, + **config.long_sequence_init_args, + ) + + else: + self.rotary_embeddings = RotaryEmbeddings( + self.hidden_size // (self.num_attention_heads * 2) + if self.position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000.0, + ) # self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob) if self.config.tensor_parallel_degree > 1: @@ -530,7 +541,6 @@ def forward( cache: Optional[Tensor] = None, use_cache: bool = False, ): - if input_ids is not None and inputs_embeds is not None: input_ids = None logger.warning("Specify both input_ids and inputs_embeds at the same time, will use inputs_embeds") @@ -544,8 +554,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = inputs_embeds.transpose([1, 0, 2]) - - rotary_embeds = self.rotary_embeddings(position_ids) + if self.config.use_long_sequence_strategies: + cos, sin = self.rotary_embeddings(seq_len=seq_length) + block_position_ids = position_ids[:, 1, :].transpose([1, 0]) + position_ids = position_ids[:, 0, :].transpose([1, 0]) + block_rotary_embeds = paddle.stack( + [cos[block_position_ids].unsqueeze(2), sin[block_position_ids].unsqueeze(2)] + ) + position_rotary_embeds = paddle.stack([cos[position_ids].unsqueeze(2), sin[position_ids].unsqueeze(2)]) + rotary_embeds = paddle.stack([position_rotary_embeds, block_rotary_embeds], axis=0) + else: + rotary_embeds = self.rotary_embeddings(position_ids) if cache is None: if self.config.pre_seq_len is not None: diff --git a/paddlenlp/transformers/chatglm_v2/configuration.py b/paddlenlp/transformers/chatglm_v2/configuration.py 
index e51322f11e0f..d8e8147eb258 100644 --- a/paddlenlp/transformers/chatglm_v2/configuration.py +++ b/paddlenlp/transformers/chatglm_v2/configuration.py @@ -56,6 +56,10 @@ def __init__( eos_token_id=2, pad_token_id=0, use_flash_attention=False, + long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs) @@ -81,3 +85,7 @@ def __init__( self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.fp32_residual_connection = fp32_residual_connection self.use_flash_attention = use_flash_attention + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 01170c19c155..bbfb6e52f481 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -21,6 +21,8 @@ from paddle.distributed.fleet.utils import recompute from paddle.utils import map_structure +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies + from ...utils.converter import StateDictNameMapping, init_name_mappings from .. import PretrainedModel, register_base_model from ..model_outputs import ( @@ -650,7 +652,15 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True): rotary_dim = ( config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels ) - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2) + if config.use_long_sequence_strategies: + self.config = config + self.rotary_pos_emb = LongSequenceStrategies.build_long_sequence_strategy( + config.long_sequence_strategy_type, + config.long_sequence_strategy_name, + **config.long_sequence_init_args, + ) + else: + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2) self.encoder = GLMTransformer(config) self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False) @@ -677,7 +687,6 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, seq_length = input_ids.shape if inputs_embeds is None: @@ -686,7 +695,14 @@ def forward( full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length) + if self.config.use_long_sequence_strategies: + cos, sin = self.rotary_pos_emb(seq_len=self.max_sequence_length) + cos, cos = paddle.chunk(cos, 2, axis=-1) + sin, sin = paddle.chunk(sin, 2, axis=-1) + rotary_pos_emb = paddle.stack([cos, sin], axis=-1) + else: + rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length) + if position_ids is not None: rotary_pos_emb = rotary_pos_emb[position_ids] else: diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py index 8303d4df0482..68459f025fe4 100644 --- a/paddlenlp/transformers/llama/configuration.py +++ b/paddlenlp/transformers/llama/configuration.py @@ -168,6 +168,10 @@ def __init__( alibi=False, rope_scaling_factor=1.0, rope_scaling_type=None, + 
long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, **kwargs, ): self.vocab_size = vocab_size @@ -208,6 +212,11 @@ def __init__( self.rope_scaling_factor = rope_scaling_factor self.rope_scaling_type = rope_scaling_type + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py old mode 100644 new mode 100755 index 94d7f5b1ef1a..cf6bde9bbe34 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -50,6 +50,7 @@ def swiglu(x, y=None): StateDictNameMapping, init_name_mappings, ) +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -762,7 +763,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): ) if config.rope: - self._init_rope() + if config.use_long_sequence_strategies: + self.rotary_emb = LongSequenceStrategies.build_long_sequence_strategy( + config.long_sequence_strategy_type, + config.long_sequence_strategy_name, + **config.long_sequence_init_args, + ) + else: + self._init_rope() self.reshard_layer = None if config.sep_parallel_degree > 1: @@ -971,7 +979,17 @@ def forward( use_neox_rotary_style=False, ) else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if self.config.use_long_sequence_strategies: + cos, sin = self.rotary_emb(seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = ( + cos.cast(value_states.dtype) if cos.dtype != value_states.dtype else cos, + sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bs, seq_len, num_head, head_dim] @@ -1324,6 +1342,7 @@ def __init__(self, config: LlamaConfig): self.sequence_parallel = config.sequence_parallel self.recompute_granularity = config.recompute_granularity self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + self.config = config # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False @@ -1476,7 +1495,15 @@ def forward( # [bs, seq_len] attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) if self.config.alibi: - alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + if self.config.use_long_sequence_strategies: + alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( + self.config.long_sequence_strategy_type, + self.config.long_sequence_strategy_name, + **self.config.long_sequence_init_args, + ) + alibi = alibi_layer(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + else: + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) if self.config.tensor_parallel_degree > 1: block_size = self.config.num_attention_heads // 
self.config.tensor_parallel_degree alibi = alibi[ diff --git a/paddlenlp/transformers/long_sequence_strategies/__init__.py b/paddlenlp/transformers/long_sequence_strategies/__init__.py new file mode 100644 index 000000000000..de784444a17c --- /dev/null +++ b/paddlenlp/transformers/long_sequence_strategies/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .attention_strategies import * +from .embedding_strategies import * +from .long_sequence_strategies import * diff --git a/paddlenlp/transformers/long_sequence_strategies/attention_strategies.py b/paddlenlp/transformers/long_sequence_strategies/attention_strategies.py new file mode 100755 index 000000000000..3b19d19452e7 --- /dev/null +++ b/paddlenlp/transformers/long_sequence_strategies/attention_strategies.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math + +import numpy as np +import paddle +from paddle import Tensor, nn + +__all__ = ["AttentionWithLinearBias"] + + +class AttentionWithLinearBias(nn.Layer): + def __init__(self, **init_args): + super().__init__() + + def _get_interleave(self, n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + return np.array([start * start**i for i in range(n)]).astype(np.float32) + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype): + attention_mask = bool_attention_mask.astype("float32") + batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1] + slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32") + with paddle.amp.auto_cast(enable=False): + alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze( + axis=[0, 1] + ).expand([num_heads, -1, -1]) + alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1]) + return paddle.cast(alibi, dtype) diff --git a/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py b/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py new file mode 100755 index 000000000000..6e9291e0d951 --- /dev/null +++ b/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle import nn + +__all__ = [ + "RotaryEmbedding", + "LinearScalingRotaryEmbedding", + "NTKScalingRotaryEmbedding", + "DynamicNTKScalingRotaryEmbedding", +] + + +class RotaryEmbedding(nn.Layer): + def __init__(self, **init_args): + super().__init__() + self.dim = init_args["dim"] + self.max_position_embeddings = init_args["max_position_embeddings"] + self.base = init_args["base"] + self.position_encoding_2d = init_args["position_encoding_2d"] if "position_encoding_2d" in init_args else False + if self.position_encoding_2d: + # [dim / 4]# 2D--Embedding + self.dim = self.dim / 2 + inv_freq = 1.0 / ( + self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim) + ) + else: + # [dim / 2] + inv_freq = 1.0 / ( + self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq) + self._set_cos_sin_cache(seq_len=self.max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype=paddle.float32) + # [seq_len, dim/2] + with paddle.amp.auto_cast(enable=False): + freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq) + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos_cached = emb.cos()[:, :] + self.sin_cached = emb.sin()[:, :] + + def forward(self, seq_len=None, ntk_alpha=None): + + return self.cos_cached[:, :], self.sin_cached[:, :] + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + def __init__(self, **init_args): + self.scaling_factor = init_args["scaling_factor"] + super().__init__(**init_args) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype=paddle.float32) + t = t / self.scaling_factor + # [seq_len, dim/2] + with paddle.amp.auto_cast(enable=False): + freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq) + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos_cached = emb.cos()[:, :] + self.sin_cached = emb.sin()[:, :] + + +class NTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/""" + + def __init__(self, **init_args): + init_args["base"] = init_args["base"] * init_args["scaling_factor"] ** ( + init_args["dim"] / (init_args["dim"] - 2) + ) + super().__init__(**init_args) + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/""" + + def __init__(self, **init_args): + self.scaling_factor = init_args["scaling_factor"] + self._seq_len_cached = 0 + super().__init__(**init_args) + + def _scale_cos_sin(self, seq_len, ntk_alpha=None): + # [seq_len] + t = paddle.arange(seq_len, dtype=paddle.float32) + if ntk_alpha is None: + ntk_alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + + # [seq_len, dim/2] + inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)) + with paddle.amp.auto_cast(enable=False): + freqs = paddle.outer(t.astype(inv_freq.dtype), inv_freq) + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos_cached = emb.cos()[:, :] + self.sin_cached = emb.sin()[:, :] + + def forward(self, seq_len=None, ntk_alpha=None): + + if seq_len > self.max_position_embeddings: + self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha) + + return self.cos_cached[:, :], self.sin_cached[:, :]
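A minimal sketch of how the two NTK variants above behave (illustration only, not part of the patch; the dim, max_position_embeddings, base and scaling_factor values are invented): NTKScalingRotaryEmbedding enlarges the base once at construction time via base * scaling_factor ** (dim / (dim - 2)), while DynamicNTKScalingRotaryEmbedding keeps the original base and only rebuilds its cos/sin cache when a forward pass sees a sequence longer than max_position_embeddings.

    from paddlenlp.transformers.long_sequence_strategies.embedding_strategies import (
        DynamicNTKScalingRotaryEmbedding,
        NTKScalingRotaryEmbedding,
    )

    init_args = {"dim": 128, "max_position_embeddings": 2048, "base": 10000, "scaling_factor": 4.0}

    # Static NTK scaling: the stretched base is baked into the cache built for 2048 positions.
    static_rope = NTKScalingRotaryEmbedding(**init_args)

    # Dynamic NTK scaling: the cache is recomputed with a sequence-length-dependent ntk_alpha
    # as soon as seq_len exceeds max_position_embeddings.
    dynamic_rope = DynamicNTKScalingRotaryEmbedding(**init_args)
    cos, sin = dynamic_rope(seq_len=4096)
    print(cos.shape, sin.shape)  # [4096, 128] [4096, 128]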
diff --git a/paddlenlp/transformers/long_sequence_strategies/long_sequence_strategies.py b/paddlenlp/transformers/long_sequence_strategies/long_sequence_strategies.py new file mode 100644 index 000000000000..286be7c5f8f8 --- /dev/null +++ b/paddlenlp/transformers/long_sequence_strategies/long_sequence_strategies.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +all_strategy_types = ["embedding_strategies", "attention_strategies"] + + +class LongSequenceStrategies: + @classmethod + def build_long_sequence_strategy(cls, strategy_type=None, stratety_name=None, **init_args): + """ + + **init_args: head_dim, + max_position_embeddings, + rope_scaling_type, + rope_scaling_factor, + ... + + strategy_type: "None" --------------- use the original built-in module + "embedding_strategies"、 + "attention_strategies" + ... + + stratety_name: "RotaryEmbedding"、 + "LinearScalingRotaryEmbedding"、 + "NTKScalingRotaryEmbedding"、 + "DynamicNTKScalingRotaryEmbedding"、 + "AttentionWithLinearBias" + ... + + """ + + """ + paddlenlp.transformers.long_sequence_strategies.{strategy_type <-> import_class}.{stratety_name <-> strategy_class} + paddlenlp.transformers.long_sequence_strategies.{embedding_strategies}.{RoPE,...} + paddlenlp.transformers.long_sequence_strategies.{attention_strategies}.{ALiBi,...} + """ + try: + import_class = importlib.import_module(f"paddlenlp.transformers.long_sequence_strategies.{strategy_type}") + except ModuleNotFoundError: + raise ModuleNotFoundError( + f"Wrong strategy type {strategy_type}. module only supports the following types: " + ", ".join(m for m in all_strategy_types) + ) + try: + strategy_class = getattr(import_class, stratety_name) + except: + all_strategy_classes = import_class.__all__ + raise LookupError( + f"module '{import_class.__name__}' only supports the following classes: " + + ", ".join(m for m in all_strategy_classes) + ) + strategy_instance = strategy_class(**init_args) + return strategy_instance
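For reference, a usage sketch for the factory above (illustration only, not part of the patch; the batch size, head count and dtype are invented, and the stratety_name keyword spelling follows the signature in this file). In the model code touched by this patch the same strings and init_args arrive through the new config fields use_long_sequence_strategies, long_sequence_strategy_type, long_sequence_strategy_name and long_sequence_init_args, and each modeling.py calls build_long_sequence_strategy itself when use_long_sequence_strategies is True.

    import paddle

    from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

    # Resolves paddlenlp.transformers.long_sequence_strategies.attention_strategies.AttentionWithLinearBias
    # and instantiates it; any init_args (none are needed here) are forwarded to the constructor.
    alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
        "attention_strategies",
        "AttentionWithLinearBias",
    )

    # Same call pattern as build_alibi_tensor in the modeling changes above:
    # a bool attention mask, the number of heads, and the target dtype.
    attention_mask = paddle.ones([2, 1024], dtype="bool")
    alibi = alibi_layer(attention_mask, 32, "float32")
    print(alibi.shape)  # [2, 32, 1, 1024]

    # An unknown strategy_type or name falls into the ModuleNotFoundError / LookupError branches above.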
diff --git a/paddlenlp/transformers/qwen/configuration.py b/paddlenlp/transformers/qwen/configuration.py index 8950519fb96d..d61a93fd6b70 100644 --- a/paddlenlp/transformers/qwen/configuration.py +++ b/paddlenlp/transformers/qwen/configuration.py @@ -50,6 +50,10 @@ def __init__( pad_token_id=0, bos_token_id=1, eos_token_id=2, + long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, **kwargs, ): self.vocab_size = vocab_size @@ -75,6 +79,11 @@ def __init__( self.use_fused_rope = use_fused_rope self.no_bias = no_bias + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py old mode 100644 new mode 100755 index 1114b3d7fc71..8cb171ab2ee1 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -25,6 +25,7 @@ from paddle.distributed.fleet.utils import recompute from paddle.utils import try_import +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -145,7 +146,14 @@ def __init__(self, config): assert config.rotary_pct < 1 self.rotary_ndims = int(self.hidden_size_per_attention_head * config.rotary_pct) dim = self.rotary_ndims if self.rotary_ndims is not None else self.hidden_size_per_attention_head - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) + if config.use_long_sequence_strategies: + self.rotary_emb = LongSequenceStrategies.build_long_sequence_strategy( + config.long_sequence_strategy_type, + config.long_sequence_strategy_name, + **config.long_sequence_init_args, + ) + else: + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) self.use_dynamic_ntk = config.use_dynamic_ntk self.use_logn_attn = config.use_logn_attn @@ -243,7 +251,6 @@ def forward( query = self._split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) - kv_seq_len = hidden_states.shape[1] if layer_past: # layer past[0] shape: bs * seq_len * head_num * dim @@ -255,7 +262,11 @@ def forward( self._ntk_cached = ntk_alpha else: ntk_alpha = self._ntk_cached - rotary_pos_emb = self.rotary_emb(value, kv_seq_len, ntk_alpha=ntk_alpha) + if self.config.use_long_sequence_strategies: + cos, sin = self.rotary_emb(seq_len=kv_seq_len, ntk_alpha=ntk_alpha) + rotary_pos_emb = (cos[None, :, None, :], sin[None, :, None, :]) + else: + rotary_pos_emb = self.rotary_emb(value, kv_seq_len, ntk_alpha=ntk_alpha) if rotary_pos_emb is not None: if isinstance(rotary_pos_emb, tuple): @@ -537,8 +548,6 @@ def 
_init_weights(self, module): module.weight.set_value( paddle.tensor.normal(mean=0.0, std=self.config.initializer_range, shape=module.weight.shape) ) - if getattr(module, "bias", None) is not None: - module.weight.set_value(paddle.zeros(shape=module.weight.shape, dtype=paddle.get_default_dtype())) for name, p in module.named_parameters(): if name == "c_proj.weight": diff --git a/tests/llm/test_long_sequence_strategies.py b/tests/llm/test_long_sequence_strategies.py new file mode 100644 index 000000000000..f7385feee640 --- /dev/null +++ b/tests/llm/test_long_sequence_strategies.py @@ -0,0 +1,5084 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +import unittest + +import numpy as np +import paddle +from parameterized import parameterized_class + +from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from .testing_utils import LLMTest, argv_context_guard, load_test_config + +all_inputs = [ + # llama-7b + [ + [ + 1, + 910, + 3461, + 8128, + 3239, + 2472, + 322, + 5626, + 363, + 11559, + 373, + 2473, + 6360, + 9580, + 545, + 358, + 313, + 17870, + 29925, + 29897, + 14974, + 6360, + 9580, + 545, + 358, + 313, + 17870, + 29925, + 29897, + 322, + 2908, + 15649, + 8078, + 292, + 313, + 29933, + 5371, + 29897, + 526, + 4266, + 8078, + 292, + 7208, + 12903, + 393, + 11559, + 3635, + 1169, + 278, + 10317, + 310, + 5282, + 1947, + 313, + 3970, + 29928, + 29897, + 304, + ] + ], + # qwen-7b + [ + [ + 1986, + 1895, + 5707, + 4004, + 1995, + 323, + 4714, + 369, + 7992, + 389, + 7299, + 3157, + 52578, + 320, + 44, + 9954, + 8, + 323, + 2504, + 3695, + 20358, + 3157, + 52578, + 320, + 44, + 9954, + 8, + 323, + 2504, + 3695, + 59406, + 320, + 66755, + 8, + 525, + 3281, + 59406, + 23783, + 429, + 7992, + 28690, + 279, + 5887, + 315, + 16373, + 320, + 35, + 2069, + 8, + 311, + 990, + 369, + 264, + 7199, + 1372, + 315, + 9055, + 23390, + ] + ], + # chatglm3-6b + [ + [ + 64790, + 64792, + 666, + 1284, + 2736, + 4467, + 1097, + 293, + 2326, + 332, + 4168, + 331, + 5332, + 2475, + 23355, + 359, + 26594, + 30947, + 30945, + 293, + 15903, + 2475, + 23355, + 359, + 26594, + 30947, + 30945, + 293, + 3579, + 2505, + 26317, + 359, + 54223, + 30945, + 383, + 1720, + 26317, + 11972, + 343, + 4168, + 15125, + 267, + 2902, + 290, + 10196, + 359, + 30952, + 3809, + 30945, + 289, + 792, + 332, + 260, + 3666, + 1276, + 290, + 5735, + 10625, + ] + ], + # chatglm-6b + [ + [ + 200, + 647, + 986, + 1186, + 320, + 102, + 953, + 108, + 2355, + 111, + 1297, + 626, + 26020, + 19, + 10806, + 266, + 14, + 102, + 130001, + 130004, + 6723, + 626, + 26020, + 19, + 10806, + 266, + 14, + 102, + 1204, + 1784, + 27817, + 19, + 27798, + 14, + 118, + 972, + 27817, + 2055, + 109, + 2355, + 9187, + 100, + 1334, + 101, + 7319, + 19, + 9220, + 234, + 14, + 103, + 179, + 108, + 104, + 1132, + 277, + 101, + 2576, + 6225, + ] + ], + # bloom + [ + [ + 55, + 75, + 76, + 86, + 210, + 85, + 72, + 83, + 
82, + 85, + 87, + 210, + 83, + 85, + 82, + 89, + 76, + 71, + 72, + 86, + 48, + 88, + 79, + 87, + 76, + 92, + 72, + 68, + 85, + 210, + 83, + 85, + 82, + 70, + 88, + 85, + 72, + 80, + 72, + 81, + 87, + 210, + 11, + 48, + 60, + 51, + 12, + 210, + 68, + 81, + 71, + 210, + 69, + 79, + 82, + 70, + 78, + 210, + ] + ], +] +all_position_ids = [ + # llama-7b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # qwen07b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # chatglm3-6b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # chatglm-6b + [ + [ + [ + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + ], + ] + ], + # bloom + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], +] +all_attention_mask = [ + # llama + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # qwen + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # chatglm3-6b + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # chatglm-6b + [ + [ + [ + [ + True, + 
True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + 
False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, 
+ False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, 
+ False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + 
False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, 
+ ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + 
+                    True, True, True, True, True, True, True, True,
+                    True, True, True, True, True, True, True, True,
+                    True, True, True, True, True, True, True, True,
+                    True, True, True, True, True, True, True,
+                    False, False, False, False, False, False, False, False,
+                    False, False, False, False, False, False, False,
+                ],
+                [True] * 44 + [False] * 14,
+                [True] * 45 + [False] * 13,
+                [True] * 46 + [False] * 12,
+                [True] * 47 + [False] * 11,
+                [True] * 48 + [False] * 10,
+                [True] * 49 + [False] * 9,
+                [True] * 50 + [False] * 8,
+                [True] * 51 + [False] * 7,
+                [True] * 52 + [False] * 6,
+                [True] * 53 + [False] * 5,
+                [True] * 54 + [False] * 4,
+                [True] * 55 + [False] * 3,
+                [True] * 56 + [False] * 2,
+                [True] * 57 + [False],
+                [True] * 58,
+            ]
+        ]
+    ],
+    # bloom
+    [[1] * 58],
+]
+all_labels = [
+    # llama
+    [
+        [-100] * 19
+        + [
+            14974, 6360, 9580, 545, 358, 313, 17870, 29925, 29897, 322, 2908, 15649, 8078, 292, 313, 29933,
+            5371, 29897, 526, 4266, 8078, 292, 7208, 12903, 393, 11559, 3635, 1169, 278, 10317, 310, 5282,
+            1947, 313, 3970, 29928, 29897, 304, 671,
+        ]
+    ],
+    # qwen
+    [
+        [-100] * 19
+        + [
+            20358, 3157, 52578, 320, 44, 9954, 8, 323, 2504, 3695, 59406, 320, 66755, 8, 525, 3281, 59406,
+            23783, 429, 7992, 28690, 279, 5887, 315, 16373, 320, 35, 2069, 8, 311, 990, 369, 264, 7199, 1372,
+            315, 9055, 23390, 7468,
+        ]
+    ],
+    # chatglm3-6b
+    [
+        [-100] * 19
+        + [
+            15903, 2475, 23355, 359, 26594, 30947, 30945, 293, 3579, 2505, 26317, 359, 54223, 30945, 383,
+            1720, 26317, 11972, 343, 4168, 15125, 267, 2902, 290, 10196, 359, 30952, 3809, 30945, 289, 792,
+            332, 260, 3666, 1276, 290, 5735, 10625, 3181,
+        ]
+    ],
+    # chatglm-6b
+    [
+        [-100] * 18
+        + [
+            130004, 6723, 626, 26020, 19, 10806, 266, 14, 102, 1204, 1784, 27817, 19, 27798, 14, 118, 972,
+            27817, 2055, 109, 2355, 9187, 100, 1334, 101, 7319, 19, 9220, 234, 14, 103, 179, 108, 104, 1132,
+            277, 101, 2576, 6225, 1785,
+        ]
+    ],
+    # bloom
+    [
+        [-100] * 19
+        + [
+            48, 88, 79, 87, 76, 92, 72, 68, 85, 210, 83, 85, 82, 70, 88, 85, 72, 80, 72, 81, 87, 210, 11, 48,
+            60, 51, 12, 210, 68, 81, 71, 210, 69, 79, 82, 70, 78, 210, 69,
+        ]
+    ],
+]
+
+all_ppl = [
+    # llama
+    31361.590644223128,
+    31361.590644223128,
+    31362.757106912533,
+    31361.62055298091,
+    # qwen
+    155909.83795939674,
+    155939.57823718787,
+    155917.27249705535,
+    155909.83795939674,
+    # chatglm3-6b
+    64415.31959719674,
+    64454.8934643284,
+    64416.60966606845,
+    64420.172847651804,
+    # chatglm-6b
+    130540.64669131214,
+    130573.01895270264,
+    130539.15278071642,
+    130538.4058318297,
+    # llama-alibi
+    31369.517462860927,
+    # bloom-alibi
+    251106.84487228873,
+]
+
+
+@parameterized_class(
+    ["model_name_or_path", "strategy_type", "strategy_name", "inputs",
+     "position_ids", "labels", "attention_mask", "ppl"],
+    [
+        # llama
+        ["__internal_testing__/micro-random-llama", "embedding_strategies", "RotaryEmbedding",
+         all_inputs[0], all_position_ids[0], all_labels[0], all_attention_mask[0], all_ppl[0]],
+        ["__internal_testing__/micro-random-llama", "embedding_strategies", "LinearScalingRotaryEmbedding",
+         all_inputs[0], all_position_ids[0], all_labels[0], all_attention_mask[0], all_ppl[1]],
+        ["__internal_testing__/micro-random-llama", "embedding_strategies", "NTKScalingRotaryEmbedding",
+         all_inputs[0], all_position_ids[0], all_labels[0], all_attention_mask[0], all_ppl[2]],
+        ["__internal_testing__/micro-random-llama", "embedding_strategies", "DynamicNTKScalingRotaryEmbedding",
+         all_inputs[0], all_position_ids[0], all_labels[0], all_attention_mask[0], all_ppl[3]],
+        # qwen
+        ["__internal_testing__/tiny-new-random-qwen-7b", "embedding_strategies", "RotaryEmbedding",
+         all_inputs[1], all_position_ids[1], all_labels[1], all_attention_mask[1], all_ppl[4]],
+        ["__internal_testing__/tiny-new-random-qwen-7b", "embedding_strategies", "LinearScalingRotaryEmbedding",
+         all_inputs[1], all_position_ids[1], all_labels[1], all_attention_mask[1], all_ppl[5]],
+        ["__internal_testing__/tiny-new-random-qwen-7b", "embedding_strategies", "NTKScalingRotaryEmbedding",
+         all_inputs[1], all_position_ids[1], all_labels[1], all_attention_mask[1], all_ppl[6]],
+        ["__internal_testing__/tiny-new-random-qwen-7b", "embedding_strategies", "DynamicNTKScalingRotaryEmbedding",
+         all_inputs[1], all_position_ids[1], all_labels[1], all_attention_mask[1], all_ppl[7]],
+        # chatglm3-6b
+        ["__internal_testing__/tiny-new-random-chatglm3-6b", "embedding_strategies", "RotaryEmbedding",
+         all_inputs[2], all_position_ids[2], all_labels[2], all_attention_mask[2], all_ppl[8]],
+        ["__internal_testing__/tiny-new-random-chatglm3-6b", "embedding_strategies", "LinearScalingRotaryEmbedding",
+         all_inputs[2], all_position_ids[2], all_labels[2], all_attention_mask[2], all_ppl[9]],
+        ["__internal_testing__/tiny-new-random-chatglm3-6b", "embedding_strategies", "NTKScalingRotaryEmbedding",
+         all_inputs[2], all_position_ids[2], all_labels[2], all_attention_mask[2], all_ppl[10]],
+        ["__internal_testing__/tiny-new-random-chatglm3-6b", "embedding_strategies", "DynamicNTKScalingRotaryEmbedding",
+         all_inputs[2], all_position_ids[2], all_labels[2], all_attention_mask[2], all_ppl[11]],
+        # chatglm-6b
+        ["__internal_testing__/tiny-new-random-chatglm-6b", "embedding_strategies", "RotaryEmbedding",
+         all_inputs[3], all_position_ids[3], all_labels[3], all_attention_mask[3], all_ppl[12]],
+        ["__internal_testing__/tiny-new-random-chatglm-6b", "embedding_strategies", "LinearScalingRotaryEmbedding",
+         all_inputs[3], all_position_ids[3], all_labels[3], all_attention_mask[3], all_ppl[13]],
+        ["__internal_testing__/tiny-new-random-chatglm-6b", "embedding_strategies", "NTKScalingRotaryEmbedding",
+         all_inputs[3], all_position_ids[3], all_labels[3], all_attention_mask[3], all_ppl[14]],
+        ["__internal_testing__/tiny-new-random-chatglm-6b", "embedding_strategies", "DynamicNTKScalingRotaryEmbedding",
+         all_inputs[3], all_position_ids[3], all_labels[3], all_attention_mask[3], all_ppl[15]],
+        # llama-alibi
+        ["__internal_testing__/micro-random-llama", "attention_strategies", "AttentionWithLinearBias",
+         all_inputs[0], all_position_ids[0], all_labels[0], all_attention_mask[0], all_ppl[16]],
+        # bloom-alibi
+        ["__internal_testing__/tiny-random-bloom", "attention_strategies", "AttentionWithLinearBias",
+         all_inputs[4], all_position_ids[4], all_labels[4], all_attention_mask[4], all_ppl[17]],
+    ],
+)
+class TestLongSequenceStrategiesTest(LLMTest, unittest.TestCase):
+    config_path: str = "./tests/fixtures/llm/predictor.yaml"
+    root_path = ""
+
+    def setUp(self) -> None:
+        super().setUp()
+        sys.path.insert(0, "./llm")
+
+    def disable_static(self):
+        paddle.utils.unique_name.switch()
+        paddle.disable_static()
+
+    def get_model(self, model_name_or_path):
+        model_config = AutoConfig.from_pretrained(model_name_or_path)
+        # ALiBi is enabled only for the attention-strategy (AttentionWithLinearBias) cases.
+        if self.strategy_type == "embedding_strategies":
+            model_config.alibi = False
+        else:
+            model_config.alibi = True
+        model_config.use_long_sequence_strategies = True
+        model_config.long_sequence_strategy_type = self.strategy_type
+        model_config.long_sequence_strategy_name = self.strategy_name
+        # Use a small max_position_embeddings for DynamicNTK so the 58-token inputs trigger dynamic scaling.
+        max_position_embeddings = 10 if self.strategy_name == "DynamicNTKScalingRotaryEmbedding" else 2048
+        model_config.long_sequence_init_args = {
+            "dim": int(model_config.hidden_size / model_config.num_attention_heads),
+            "max_position_embeddings": max_position_embeddings,
+            "base": 10000,
+            "scaling_factor": 4,
+        }
+        if "chatglm" in model_name_or_path:
+            model_config.long_sequence_init_args["position_encoding_2d"] = True
+        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config, dtype="float32")
+        return model
+
+    def test_long_sequence_strategies(self):
+        input_ids = paddle.to_tensor(self.inputs, dtype=paddle.int64)
+        position_ids = paddle.to_tensor(self.position_ids, dtype=paddle.int64)
+        attention_mask = paddle.to_tensor(self.attention_mask, dtype=paddle.int64)
+        labels = paddle.to_tensor(self.labels, dtype=paddle.int64)
+        ppl = self.ppl
+        inputs = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "labels": labels,
+            "attention_mask": attention_mask,
+        }
+        model = self.get_model(self.model_name_or_path)
+
+        output = model(**inputs)
+        # output[0] is the mean cross-entropy loss; exp(loss) is compared against the expected perplexity.
+        self.assertTrue(
+            np.allclose(
+                np.exp(output[0].item()),
+                ppl,
+                rtol=1e-2,
+            )
+        )
+
+    def test_dynamic_to_static_inference(self):
+        if (
+            "qwen" not in self.model_name_or_path
+            and "chatglm-6b" not in self.model_name_or_path
+            and "bloom" not in self.model_name_or_path
+        ):
+            model = self.get_model(self.model_name_or_path)
+            save_path = os.path.join(self.output_dir, self.model_name_or_path)
+            model.save_pretrained(save_path)
+            tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
+            if "llama" in self.model_name_or_path:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            tokenizer.save_pretrained(save_path)
+
+            self.disable_static()
+            config = load_test_config(self.config_path, "inference-to-static")
+            config["output_path"] = self.inference_output_dir
+            config["model_name_or_path"] = save_path
+
+            with argv_context_guard(config):
+                from export_model import main
+
+                main()