Commit

… into release/2.8
DesmonDay committed Sep 23, 2024
2 parents b5a9d26 + c39c08e commit 3ab7cd9
Showing 4 changed files with 116 additions and 17 deletions.
20 changes: 18 additions & 2 deletions paddlenlp/trainer/argparser.py
@@ -24,7 +24,17 @@
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
from typing import (
Any,
Dict,
Iterable,
NewType,
Optional,
Tuple,
Union,
get_args,
get_type_hints,
)

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -129,7 +139,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif isclass(origin_type) and issubclass(origin_type, list):
kwargs["type"] = field.type.__args__[0]
# support one-dimensional and two-dimensional lists
if hasattr(get_args(field.type)[0], "__args__"):
kwargs["type"] = field.type.__args__[0].__args__[0]
kwargs["action"] = "append"
else:
kwargs["type"] = field.type.__args__[0]

kwargs["nargs"] = "+"
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
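The new branch maps a two-dimensional list field (e.g. List[List[int]]) onto argparse's action="append" combined with nargs="+", so each occurrence of the flag contributes one inner list. A minimal sketch of the expected behavior using plain argparse (PdArgumentParser adds more wiring on top; the flag name below simply mirrors the skip_data_intervals field added later in this commit):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Element type int; one inner list is appended per occurrence of the flag.
    parser.add_argument("--skip_data_intervals", type=int, nargs="+", action="append", default=None)

    args = parser.parse_args("--skip_data_intervals 10 20 --skip_data_intervals 100 200".split())
    print(args.skip_data_intervals)  # [[10, 20], [100, 200]]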
92 changes: 77 additions & 15 deletions paddlenlp/trainer/trainer.py
@@ -125,6 +125,7 @@
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IntervalStrategy,
IterableDatasetShard,
OptimizerNames,
PredictionOutput,
@@ -137,6 +138,7 @@
get_scheduler,
has_length,
set_seed,
should_skip_data,
speed_metrics,
)
from .training_args import TrainingArguments
@@ -274,9 +276,16 @@ def __init__(

# Seed must be set before instantiating the model when using model
set_seed(seed=self.args.seed)

self._skip_global_steps = 0 # total skip global steps
self._skip_steps_since_last_logged = 0 # skip steps since last logged
if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
logger.warning("Model is None.")
self.model = None
self.train_dataset = train_dataset
self.tokenizer = tokenizer
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
return

if self.args.to_static:
model = paddle.jit.to_static(model)
@@ -897,6 +906,7 @@ def _inner_training_loop(
step_control = 0 # used in loop control, reset to 0 after every step
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

step = -1
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
@@ -929,6 +939,44 @@
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
# skip this step

if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# update current global step and skip step
self.state.global_step += 1
self._skip_global_steps += 1
self._skip_steps_since_last_logged += 1

self.state.epoch = epoch + (step + 1) / steps_in_epoch

if self.state.global_step == 1 and self.args.logging_first_step:
self.control.should_log = True
if (
self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.args.logging_steps == 0
):
self.control.should_log = True

self.control.should_evaluate = False
self.control.should_save = False

# log loss and memory usage
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
self._print_timer()
step_control = 0
else:
step_control += 1
if self.state.global_step >= self.state.max_steps:
break

self.timers and self.timers("read-data").start()
continue

if step_control % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()
@@ -1146,7 +1194,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
)

self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step

# In case all steps were skipped, the total loss is set to 0.
if self.state.global_step == self._skip_global_steps:
logger.info("All steps were skipped, the total loss is set to 0.")
train_loss = 0.0
else:
train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)

@@ -1261,14 +1315,19 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
if self.control.should_log:

logs: Dict[str, float] = {}

num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
self._skip_steps_since_last_logged = 0
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())

# reset tr_loss to zero
tr_loss.subtract_(tr_loss)
# set loss to zero if all steps are skipped since last log
if num_steps == 0:
logs["loss"] = 0.0
else:
logs["loss"] = round(tr_loss_scalar / num_steps, 8)

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)

@@ -1289,19 +1348,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
total_train_batch_size = (
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
)
num_steps = self.state.global_step - self._globalstep_last_logged

seq_length = None
if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
seq_length = getattr(self.model.config, "seq_length", None)
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,

# Do not log speed metrics if all steps are skipped since last log.
if num_steps > 0:
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
)
)
)

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
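A worked example of the logging arithmetic above (numbers are illustrative only): if the global step advanced from 100 to 110 since the last log and 4 of those steps fell inside a skip interval, the interval loss is averaged over num_steps = 110 - 100 - 4 = 6 real optimizer steps, and the speed metrics cover those same 6 steps; if all 10 were skipped, num_steps is 0, the logged loss is reported as 0.0, and the interval speed metrics are omitted rather than dividing by zero.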
@@ -3152,7 +3214,7 @@ def _set_signature_columns_if_needed(self):
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
if not self.args.remove_unused_columns or self.model is None:
return dataset
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
17 changes: 17 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -1092,3 +1092,20 @@ def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
tracker.add("local_seed", local_seed)


def should_skip_data(global_step, skip_data_intervals):
"""Whether to skip current step data"""

if skip_data_intervals is None:
return False
skip_flag = False
for interval in skip_data_intervals:
if len(interval) != 2 or interval[0] > interval[1] or interval[0] <= 0:
raise ValueError(f"Please check your skip interval {interval}")
start_global_step, end_global_step = interval[0], interval[1]
# start_global_step and end_global_step start from 1, while global_step starts from 0
if start_global_step <= global_step + 1 <= end_global_step:
skip_flag = True
break
return skip_flag
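Given the 1-based, inclusive interval semantics (global_step itself is 0-based), a few illustrative calls:

    print(should_skip_data(0, [[1, 5]]))           # True  -> step 0 is the 1st global step, inside [1, 5]
    print(should_skip_data(4, [[1, 5]]))           # True  -> 5th global step, still inside
    print(should_skip_data(5, [[1, 5]]))           # False -> 6th global step, past the interval
    print(should_skip_data(3, [[1, 2], [9, 12]]))  # False -> 4th global step falls between the intervals
    print(should_skip_data(7, None))               # False -> skipping disabled when no intervals are configured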
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -797,6 +797,10 @@ class TrainingArguments:
release_grads: Optional[bool] = field(
default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
)
skip_data_intervals: Optional[List[List[int]]] = field(
default=None,
metadata={"help": "The intervals to skip, pass start global step and end global step at each interval"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
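Assuming the argparser change earlier in this commit is what parses this field from the command line, each interval would be passed as its own occurrence of the flag, e.g. --skip_data_intervals 10 20 --skip_data_intervals 100 200 for skip_data_intervals == [[10, 20], [100, 200]] (global steps 10-20 and 100-200 skipped); the exact CLI behavior depends on PdArgumentParser and is stated here as an assumption.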
