Remove shift_lm_labels attributes from all models and scripts (PaddlePaddle#6675)

* test 2.5.1

* turn off more tests

* turn off taskflow tests

* changes

* changes

* changes
sijunhe authored Aug 10, 2023
1 parent 40bbcb7 commit f59274a
Showing 29 changed files with 80 additions and 141 deletions.
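At a glance, the commit removes the `lm_shift_labels` switch from configs, loss code, and evaluation, and instead relies on the data pipeline to emit already-shifted `input_ids`/`labels` pairs (see the `convert_example` change in `llm/gpt-3/utils.py` below). The following is a minimal sketch of that convention, using assumed toy values rather than repository code:

```python
# Minimal sketch, not repository code: shapes, ids, and vocab size are assumed.
import paddle

vocab_size = 32
token_ids = list(range(10))  # a toy tokenized sample

# New convention, applied once in preprocessing: position t of input_ids is
# supervised with token t+1, so the loss needs no conditional slicing.
input_ids = token_ids[:-1]   # model input: tokens 0..n-2
labels = token_ids[1:]       # targets:     tokens 1..n-1

# The loss forward in llm/gpt-3/modeling.py can then score logits[t] against
# labels[t] directly, which is what the simplified code below does.
loss_fn = paddle.nn.CrossEntropyLoss(reduction="none")
logits = paddle.randn([1, len(input_ids), vocab_size])  # stand-in for model(input_ids)
loss = loss_fn(logits.astype("float32"), paddle.to_tensor([labels]).unsqueeze(2))
```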
1 change: 0 additions & 1 deletion llm/bloom/finetune_generation.py
@@ -196,7 +196,6 @@ def main():
dtype=dtype, # todo enable set dtype to avoid additional mem usage
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
lm_shift_labels=False,
)

if model_args.lora:
5 changes: 1 addition & 4 deletions llm/bloom/utils.py
@@ -44,10 +44,7 @@ def prediction_step(
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# argmax here to avoid gather all logits, which is too memory-consuming.
# keepdim in order to maintain the same shape as logits
if model.config.lm_shift_labels:
return (loss, logits[0][..., :-1, :].argmax(axis=-1, keepdim=True), labels[..., 1:])
else:
return (loss, logits[0].argmax(axis=-1, keepdim=True), labels)
return (loss, logits[0].argmax(axis=-1, keepdim=True), labels)

model.eval()
with paddle.no_grad():
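The same `prediction_step` simplification recurs below in `llm/chatglm/utils.py`, `llm/ernie-3.5-se/utils.py`, `llm/gpt-3/utils.py`, and `llm/llama/utils.py`: with labels pre-shifted, only the unconditional return survives. The `argmax(axis=-1, keepdim=True)` is kept so evaluation gathers one id per position instead of the full logits; a rough sketch with assumed shapes, not repository code:

```python
# Rough sketch with assumed shapes: argmax with keepdim keeps the rank of the
# logits while shrinking the last axis from vocab_size to 1, which is far
# cheaper to gather across workers during evaluation.
import paddle

logits = paddle.randn([2, 128, 50304])           # [batch, seq_len, vocab_size]
pred_ids = logits.argmax(axis=-1, keepdim=True)  # [2, 128, 1]
```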
5 changes: 1 addition & 4 deletions llm/causallm/finetune_generation.py
@@ -79,10 +79,7 @@ def main():
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
)
# Alreday shift label & logit in convert example
# lm_shift_labels should be set before model initilization for some models(ex. llama)
if hasattr(model_config, "lm_shift_labels"):
model_config.lm_shift_labels = False

if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention
if hasattr(model_config, "max_position_embeddings"):
2 changes: 0 additions & 2 deletions llm/chatglm/finetune_generation.py
@@ -114,8 +114,6 @@ def main():
else:
multi_query_group_num = None
attention_mask_pad_fn = None
# If ChatGLM, set lm_shift_labels to False
model.config.lm_shift_labels = False

if model_args.prefix_tuning:
prefix_config = PrefixConfig(
5 changes: 1 addition & 4 deletions llm/chatglm/utils.py
@@ -58,10 +58,7 @@ def prediction_step(
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# argmax here to avoid gather all logits, which is too memory-consuming.
# keepdim in order to maintain the same shape as logits
if hasattr(model.config, "lm_shift_labels") and model.config.lm_shift_labels:
return (loss, logits[0][..., :-1, :].argmax(axis=-1, keepdim=True), labels[..., 1:])
else:
return (loss, logits[0].argmax(axis=-1, keepdim=True), labels)
return (loss, logits[0].argmax(axis=-1, keepdim=True), labels)

loss = None

1 change: 0 additions & 1 deletion llm/ernie-3.5-se/run_pretrain.py
@@ -369,7 +369,6 @@ def main():
config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.")

config.lm_shift_labels = False
config.use_flash_attention = model_args.use_flash_attention
config.fuse_ln = model_args.use_fused_ln
config.fuse_attention_qkv = model_args.fuse_attention_qkv
5 changes: 1 addition & 4 deletions llm/ernie-3.5-se/utils.py
@@ -138,10 +138,7 @@ def prediction_step(
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# argmax here to avoid gather all logits, which is too memory-consuming.
# keepdim in order to maintain the same shape as logits
if model.config.lm_shift_labels:
return (loss, logits[..., :-1, :].argmax(axis=-1, keepdim=True), labels[..., 1:])
else:
return (loss, logits.argmax(axis=-1, keepdim=True), labels)
return (loss, logits.argmax(axis=-1, keepdim=True), labels)

model.eval()

4 changes: 1 addition & 3 deletions llm/gpt-3/configuration.py
@@ -278,12 +278,11 @@ def __init__(
bos_token_id: int = 0,
eol_token_id: int = 3,
num_partitions: int = 1,
lm_shift_labels: bool = True,
normalize_before: bool = True,
recompute_granularity: str = "full",
scale_qk_coeff: float = 1.0,
tensor_parallel_degree: int = 1,
tensor_parallel_output:bool = True,
tensor_parallel_output: bool = True,
output_attentions: bool = False,
ignore_index: int = 0,
use_flash_attention: bool = False,
@@ -312,7 +311,6 @@ def __init__(
self.bos_token_id = bos_token_id
self.eol_token_id = eol_token_id
self.num_partitions = num_partitions
self.lm_shift_labels = lm_shift_labels
self.normalize_before = normalize_before
self.recompute_granularity = recompute_granularity
self.scale_qk_coeff = scale_qk_coeff
1 change: 0 additions & 1 deletion llm/gpt-3/finetune_generation.py
@@ -142,7 +142,6 @@ def main():
config.fuse_attention_qkv = model_args.fuse_attention_qkv
config.use_flash_attn = model_args.use_flash_attn
config.use_recompute = training_args.recompute
config.lm_shift_labels = True

config.tensor_parallel_degree = training_args.tensor_parallel_degree
config.tensor_parallel_rank = training_args.tensor_parallel_rank
79 changes: 56 additions & 23 deletions llm/gpt-3/modeling.py
@@ -31,7 +31,6 @@
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list

from paddlenlp.transformers import PretrainedModel, register_base_model
from paddlenlp.transformers.model_outputs import CausalLMOutputWithCrossAttentions
@@ -96,7 +95,10 @@ class MultiHeadAttention(nn.Layer):
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

def __init__(self, config,):
def __init__(
self,
config,
):
super(MultiHeadAttention, self).__init__()

self.config = config
@@ -107,7 +109,9 @@ def __init__(self, config,):
self.use_flash_attention = config.use_flash_attention if flash_attention else None

self.head_dim = config.hidden_size // config.num_attention_heads
assert self.head_dim * config.num_attention_heads == config.hidden_size, "hidden_size must be divisible by num_attention_heads"
assert (
self.head_dim * config.num_attention_heads == config.hidden_size
), "hidden_size must be divisible by num_attention_heads"

self.num_attention_heads = config.num_attention_heads # default, without tensor parallel
if config.tensor_parallel_degree > 1:
@@ -246,7 +250,13 @@ def gen_cache(self, key, value=None, type=Cache):

def _flash_attention(self, q, k, v, attn_mask=None, output_attentions=False):
out, weights = flash_attention(
q, k, v, self.config.hidden_dropout_prob, causal=True, return_softmax=output_attentions, training=self.training
q,
k,
v,
self.config.hidden_dropout_prob,
causal=True,
return_softmax=output_attentions,
training=self.training,
)
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
return (out, weights) if output_attentions else out
@@ -278,9 +288,13 @@ def core_attn(self, q, k, v, attn_mask=None, output_attentions=False):
if self.config.hidden_dropout_prob:
if self.training:
with get_rng_state_tracker().rng_state("local_seed"):
weights = F.dropout(weights, self.config.hidden_dropout_prob, training=self.training, mode="upscale_in_train")
weights = F.dropout(
weights, self.config.hidden_dropout_prob, training=self.training, mode="upscale_in_train"
)
else:
weights = F.dropout(weights, self.config.hidden_dropout_prob, training=self.training, mode="upscale_in_train")
weights = F.dropout(
weights, self.config.hidden_dropout_prob, training=self.training, mode="upscale_in_train"
)

out = paddle.matmul(weights, v)

@@ -346,7 +360,9 @@ def __init__(
# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False

def forward(self, tgt, tgt_mask=None, memory=None, memory_mask=None, use_cache=False, cache=None, output_attentions=False):
def forward(
self, tgt, tgt_mask=None, memory=None, memory_mask=None, use_cache=False, cache=None, output_attentions=False
):
r"""
Applies a stack of N Transformer decoder layers on inputs. If `norm` is
provided, also applies layer normalization on the output of last decoder
Expand All @@ -359,17 +375,33 @@ def forward(self, tgt, tgt_mask=None, memory=None, memory_mask=None, use_cache=F
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
output, new_cache = mod(output, tgt_mask=tgt_mask, memory=memory, use_cache=use_cache, cache=cache, output_attentions=output_attentions)
output, new_cache = mod(
output,
tgt_mask=tgt_mask,
memory=memory,
use_cache=use_cache,
cache=cache,
output_attentions=output_attentions,
)
new_caches.append(new_cache)
else:
has_gradient = not output.stop_gradient
if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
output = recompute(mod, output, tgt_mask, memory, use_cache, cache, output_attentions, use_reentrant=False)
output = recompute(
mod, output, tgt_mask, memory, use_cache, cache, output_attentions, use_reentrant=False
)
else:
output = mod(output, tgt_mask, memory, use_cache, cache, output_attentions)

else:
output, new_cache = mod(output, tgt_mask=tgt_mask, memory=memory, use_cache=use_cache, cache=cache[i], output_attentions=output_attentions)
output, new_cache = mod(
output,
tgt_mask=tgt_mask,
memory=memory,
use_cache=use_cache,
cache=cache[i],
output_attentions=output_attentions,
)
new_caches.append(new_cache)

if output_attentions:
@@ -410,12 +442,12 @@ class TransformerDecoderLayer(nn.Layer):
def __init__(self, config: GPTConfig):

super(TransformerDecoderLayer, self).__init__()

self.config = config

# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False

if not FusedDropoutAdd:
config.use_fused_dropout_add = False

@@ -442,7 +474,7 @@ def __init__(self, config: GPTConfig):

self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)

if not config.use_fused_dropout_add:
self.dropout1 = nn.Dropout(config.hidden_dropout_prob, mode="upscale_in_train")
self.dropout2 = nn.Dropout(config.hidden_dropout_prob, mode="upscale_in_train")
@@ -461,7 +493,9 @@ def forward(self, tgt, tgt_mask=None, memory=None, use_cache=False, cache=None,
if use_cache is False:
has_gradient = not tgt.stop_gradient
if self.enable_recompute and self.config.recompute_granularity == "full_attn" and has_gradient:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache, output_attentions, use_reentrant=False)
tgt = recompute(
self.self_attn, tgt, None, None, tgt_mask, use_cache, cache, output_attentions, use_reentrant=False
)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache, output_attentions)
else:
@@ -519,7 +553,10 @@ class GPTEmbeddings(nn.Layer):
Include embeddings from word, position and token_type embeddings
"""

def __init__(self, config,):
def __init__(
self,
config,
):
super(GPTEmbeddings, self).__init__()

self.config = config
@@ -674,7 +711,9 @@ def __init__(self, config: GPTConfig):
decoder_layers,
)

def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=False, cache=None, output_attentions=False):
def forward(
self, input_ids, position_ids=None, attention_mask=None, use_cache=False, cache=None, output_attentions=False
):
if position_ids is None:
past_length = 0
if cache is not None:
@@ -728,12 +767,6 @@ def __init__(self, config):
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=config.ignore_index)

def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):

if self.config.lm_shift_labels:
# Shift so that tokens < n predict n
prediction_scores = prediction_scores[..., :-1, :]
masked_lm_labels = masked_lm_labels[..., 1:]

with paddle.amp.auto_cast(False):
masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
1 change: 0 additions & 1 deletion llm/gpt-3/run_pretrain.py
@@ -379,7 +379,6 @@ def main():
config.fuse_attention_qkv = model_args.fuse_attention_qkv
config.use_recompute = training_args.recompute
config.use_flash_attention = model_args.use_flash_attention
config.lm_shift_labels = False

config.tensor_parallel_degree = training_args.tensor_parallel_degree
config.tensor_parallel_rank = training_args.tensor_parallel_rank
17 changes: 9 additions & 8 deletions llm/gpt-3/utils.py
@@ -197,18 +197,18 @@ def convert_example(

outputs = tokenizer(
output_seq,
max_seq_len=max_target_length,
max_length=max_target_length,
# pad_to_max_seq_len=True,
truncation_strategy="longest_first",
return_attention_mask=True,
return_attention_mask=False,
return_token_type_ids=False,
)
inputs = tokenizer(
input_seq,
max_seq_len=max_source_length,
max_length=max_source_length,
# pad_to_max_seq_len=True,
truncation_strategy="longest_first",
return_attention_mask=True,
return_attention_mask=False,
return_length=False,
)

@@ -217,6 +217,10 @@
final[k] = inputs[k] + outputs[k]
if k == "input_ids":
final["labels"] = [tokenizer.pad_token_id] * len(inputs["input_ids"]) + outputs[k]

# shift inputs and labels
final["input_ids"] = final["input_ids"][:-1]
final["labels"] = final["labels"][1:]
return final
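A toy walk-through of the shift added above, with assumed token and pad ids rather than repository values: dropping the last input token and the first label aligns each remaining position with the token it should predict.

```python
# Toy illustration with assumed values; mirrors the shift added in convert_example.
pad_token_id = 0            # stand-in for tokenizer.pad_token_id
source_ids = [5, 6, 7]      # tokenized input_seq
target_ids = [8, 9, 2]      # tokenized output_seq (2 = assumed eos id)

final = {
    "input_ids": source_ids + target_ids,                     # [5, 6, 7, 8, 9, 2]
    "labels": [pad_token_id] * len(source_ids) + target_ids,  # [0, 0, 0, 8, 9, 2]
}

# shift inputs and labels, as in the new code
final["input_ids"] = final["input_ids"][:-1]  # [5, 6, 7, 8, 9]
final["labels"] = final["labels"][1:]         # [0, 0, 8, 9, 2]
# input_ids[t] now lines up with labels[t], the token the model should predict next.
```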


@@ -318,10 +322,7 @@ def prediction_step(
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# argmax here to avoid gather all logits, which is too memory-consuming.
# keepdim in order to maintain the same shape as logits
if model.config.lm_shift_labels:
return (loss, logits[..., :-1, :].argmax(axis=-1, keepdim=True), labels[..., 1:])
else:
return (loss, logits.argmax(axis=-1, keepdim=True), labels)
return (loss, logits.argmax(axis=-1, keepdim=True), labels)

model.eval()

1 change: 0 additions & 1 deletion llm/llama/finetune_generation.py
@@ -145,7 +145,6 @@ def main():
tensor_parallel_rank=training_args.tensor_parallel_rank,
use_flash_attention=model_args.use_flash_attention,
dtype=dtype, # todo enable set dtype to avoid additional mem usage
lm_shift_labels=False,
)
if model_args.lora:
if model_args.lora_path is None:
1 change: 0 additions & 1 deletion llm/llama/run_pretrain.py
@@ -439,7 +439,6 @@ def main():
config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.")

config.lm_shift_labels = False
config.use_flash_attention = model_args.use_flash_attention
config.use_fused_rms_norm = model_args.use_fused_rms_norm
config.fuse_attention_qkv = model_args.fuse_attention_qkv
2 changes: 0 additions & 2 deletions llm/llama/tests/test_pipeline_parallel.py
@@ -61,7 +61,6 @@ def test_pipeline_model(self):
num_attention_heads=32,
tensor_parallel_degree=tp_degree,
tensor_parallel_rank=hcg.get_model_parallel_rank(),
lm_shift_labels=True,
tensor_parallel_output=False,
# use_flash_attention=True,
)
@@ -100,7 +99,6 @@ def test_pipeline_model(self):

single_model = LlamaForCausalLM.from_pretrained(
model_name_or_path,
lm_shift_labels=True,
num_attention_heads=32,
tensor_parallel_output=False,
)
1 change: 0 additions & 1 deletion llm/llama/tests/test_sequence_parallel.py
@@ -60,7 +60,6 @@ def test_sequence_model(self):

config = LlamaConfig.from_pretrained(model_name_or_path)
config.seq_length = seq_len
config.lm_shift_labels = False
config.use_flash_attention = False
config.use_fused_rms_norm = False
config.fuse_attention_qkv = False
5 changes: 1 addition & 4 deletions llm/llama/utils.py
@@ -138,10 +138,7 @@ def prediction_step(
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# argmax here to avoid gather all logits, which is too memory-consuming.
# keepdim in order to maintain the same shape as logits
if model.config.lm_shift_labels:
return (loss, logits[..., :-1, :].argmax(axis=-1, keepdim=True), labels[..., 1:])
else:
return (loss, logits.argmax(axis=-1, keepdim=True), labels)
return (loss, logits.argmax(axis=-1, keepdim=True), labels)

model.eval()

1 change: 0 additions & 1 deletion llm/opt/finetune_generation.py
@@ -105,7 +105,6 @@ def main():
# dtype=dtype, # todo enable set dtype to avoid additional mem usage
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
lm_shift_labels=False,
)
if model_args.lora:
# TODO: hardcode parameters for now. Change after MergedLoRA is introduced