diff --git a/oslo/transformers/trainer.py b/oslo/transformers/trainer.py
index 85a83fab..e3d46449 100644
--- a/oslo/transformers/trainer.py
+++ b/oslo/transformers/trainer.py
@@ -8,7 +8,8 @@
 import random
 import warnings
 from packaging import version
-from typing import Any, Dict, List, Optional, Union, Type, Mapping, Tuple
+from typing import Any, Dict, List, Optional, Union, Type, Mapping, Tuple, Callable
+from tqdm.auto import tqdm

 import torch
 import torch.distributed as dist
@@ -50,11 +51,14 @@
     TrainOutput,
     EvalPrediction,
     PREFIX_CHECKPOINT_DIR,
+    get_last_checkpoint,
 )

 import oslo
 from oslo.torch import ParallelMode
-from oslo.torch.utils.extensions import save_pretrained as save_pretrained_oslo
+
+# from oslo.torch.utils.extensions import save_pretrained as save_oslo_pretrained
+# from oslo.torch.utils.extensions import from_parallelized as from_oslo_parallelized
 from oslo.torch.nn.parallel import (
     PipelineParallel,
     TensorParallel,
@@ -67,13 +71,12 @@
 from oslo.transformers.trainer_utils import OptimizerNames, log_dist
 from oslo.transformers.training_args import TrainingArguments

-TRAINING_ARGS_NAME = "training_args.bin"
 TRAINER_STATE_NAME = "trainer_state.json"
 OPTIMIZER_NAME = "optimizer.pt"
 SCHEDULER_NAME = "scheduler.pt"
 SCALER_NAME = "scaler.pt"
 WEIGHTS_NAME = "pytorch_model.bin"
-CONFIG_NAME = "config.yaml"
+OSLO_CONFIG_NAME = "config.yaml"

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
 DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -84,11 +87,13 @@ def __init__(
         self,
         model: nn.Module = None,
         args: TrainingArguments = None,
+        load_args_from_saved: Optional[Union[str, bool]] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
         data_collator: Optional[DataCollator] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        # compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
     ):

         if args is None:
@@ -97,7 +102,17 @@ def __init__(
             log_dist(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
             args = TrainingArguments(output_dir=output_dir)

+        if load_args_from_saved:
+            if resume_from_checkpoint and load_args_from_saved is not None:
+                load_args_path = (
+                    load_args_from_saved
+                    if isinstance(load_args_from_saved, str)
+                    else resume_from_checkpoint
+                )
+                args.load_args(load_args_path)
+
         self.args = args
+        self.resume_from_checkpoint = resume_from_checkpoint

         default_collator = (
             default_data_collator
@@ -116,34 +131,30 @@ def __init__(
         self.parallel_context = None
         self.model_wrappers = []
-        self.label_smoother = None  # TODO
+        self.label_smoother = None  # TODO: label_smoother

         self.parallel_context, self.model_wrappers = (
             args.parallel_context,
             args.model_wrappers,
         )

-        if (
-            len(self.model_wrappers)
-            > 0
-            # or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
-        ):
+        if len(self.model_wrappers) > 0:
             self.place_model_on_device = False
         else:
             self.place_model_on_device = True

         if self.place_model_on_device:
-            # log_dist(f"model device, args.device: {self.args.device}", rank=-1)
             kwargs = dict(device=self.args.device)
             model = model.to(**kwargs)

+        self.model = model
+
         if args.save_on_each_node:
             self.should_save = self.is_local_process_zero()
         else:
             self.should_save = self.is_world_process_zero()

-        self.model = model
-
+        self.compute_metrics = compute_metrics

         # Define and add callback
         default_callbacks = DEFAULT_CALLBACKS
         callbacks = default_callbacks
@@ -164,8 +175,9 @@ def __init__(
         self.do_grad_scaling = False
         if args.fp16 or args.bf16:
             self.do_grad_scaling = True
+            # TODO: Set ShardedGradScaler when oslo feature of it is ready
+            # self.scaler = ShardedGradScaler()

-        # TODO Label Smoother
+        # TODO: Label Smoother

         self.state = TrainerState(
             is_local_process_zero=self.is_local_process_zero(),
@@ -184,23 +196,12 @@ def __init__(
             self.args, self.state, self.control
         )

-    def train(
-        self,
-        resume_from_checkpoint: Optional[Union[str, bool]] = None,
-    ):
+    def train(self):
         resume_from_checkpoint = (
-            None if not resume_from_checkpoint else resume_from_checkpoint
+            None if not self.resume_from_checkpoint else self.resume_from_checkpoint
         )
         args = self.args

-        # TODO Load potential model checkpoint
-        # if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
-        #     resume_from_checkpoint = get_last_checkpoint(args.output_dir)
-        #     if resume_from_checkpoint is None:
-        #         raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
-        #
-        # if resume_from_checkpoint is not None:
-        #     self._load_from_checkpoint(resume_from_checkpoint)

         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -250,13 +251,23 @@ def train(
                 self.parallel_context,
                 **self.args.oslo_config["activation_checkpointing"],
             )
-
         self.model = self._wrap_model(self.model_wrappers)
+        # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(
+                    f"No valid checkpoint found in output directory ({args.output_dir})"
+                )
+
+        if resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint)
+
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)

-        # # TODO Check if saved optimizer or scheduler states exist
-        # self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        # Check if saved optimizer or scheduler states exist
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)

         # Train!
         log_dist("***** Running training *****")
@@ -274,33 +285,42 @@ def train(
         self.state.epoch = 0
         start_time = time.time()
         epochs_trained = 0
-        # steps_trained_in_current_epoch = 0
-        # steps_trained_progress_bar = None
-
-        # # TODO Check if continuing training from a checkpoint
-        # if resume_from_checkpoint is not None and os.path.isfile(
-        #     os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
-        # ):
-        #     self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
-        #     epochs_trained = self.state.global_step // num_update_steps_per_epoch
-        #     if not args.ignore_data_skip:
-        #         steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-        #         steps_trained_in_current_epoch *= args.gradient_accumulation_steps
-        #     else:
-        #         steps_trained_in_current_epoch = 0
-        #
-        #     logger.info(" Continuing training from checkpoint, will skip to saved global_step")
-        #     logger.info(f" Continuing training from epoch {epochs_trained}")
-        #     logger.info(f" Continuing training from global step {self.state.global_step}")
-        #     if not args.ignore_data_skip:
-        #         logger.info(
-        #             f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
-        #             "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
-        #             "flag to your launch command, but you will resume the training on data already seen by your model."
-        #         )
-        #     if self.is_local_process_zero() and not args.disable_tqdm:
-        #         steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
-        #         steps_trained_progress_bar.set_description("Skipping the first batches")
+        steps_trained_in_current_epoch = 0
+        steps_trained_progress_bar = None
+
+        if resume_from_checkpoint is not None and os.path.isfile(
+            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+        ):
+            self.state = TrainerState.load_from_json(
+                os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+            )
+            epochs_trained = self.state.global_step // num_update_steps_per_epoch
+            if not args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (
+                    num_update_steps_per_epoch
+                )
+                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0
+
+            log_dist(
+                " Continuing training from checkpoint, will skip to saved global_step"
+            )
+            log_dist(f" Continuing training from epoch {epochs_trained}")
+            log_dist(f" Continuing training from global step {self.state.global_step}")
+            if not args.ignore_data_skip:
+                log_dist(
+                    f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
+                    "flag to your launch command, but you will resume the training on data already seen by your model."
+                )
+            if self.is_local_process_zero():
+                steps_trained_progress_bar = tqdm(
+                    total=steps_trained_in_current_epoch
+                )
+                steps_trained_progress_bar.set_description(
+                    "Skipping the first batches"
+                )

         # Update the references
         self.callback_handler.model = self.model
@@ -322,10 +342,29 @@ def train(
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         self.optimizer.zero_grad()
+        self.model.zero_grad()

         self.control = self.callback_handler.on_train_begin(
             args, self.state, self.control
         )
+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
+                    train_dataloader.sampler, RandomSampler
+                )
+                if (
+                    version.parse(torch.__version__) < version.parse("1.11")
+                    or not is_random_sampler
+                ):
+                    # We just need to begin an iteration to create the randomization of the sampler.
+                    # That was before PyTorch 1.11 however...
+                    for _ in train_dataloader:
+                        break
+                else:
+                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
+                    # AT THE VERY END!
+                    _ = list(train_dataloader.sampler)

         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(
@@ -338,9 +377,7 @@ def train(
                 train_dataloader.dataset.set_epoch(epoch)

             epoch_iterator = train_dataloader
-            # # TODO Reset the past mems state at the beginning of each epoch if necessary.
-            # if args.past_index >= 0:
-            #     self._past = None
+
             steps_in_epoch = (
                 len(epoch_iterator)
                 if len_dataloader is not None
@@ -352,17 +389,18 @@ def train(
             step = -1
             for step, inputs in enumerate(epoch_iterator):
-                # # TODO Skip past any already trained steps if resuming training
-                # if steps_trained_in_current_epoch > 0:
-                #     steps_trained_in_current_epoch -= 1
-                #     if steps_trained_progress_bar is not None:
-                #         steps_trained_progress_bar.update(1)
-                #     if steps_trained_in_current_epoch == 0:
-                #         self._load_rng_state(resume_from_checkpoint)
-                #     continue
-                # elif steps_trained_progress_bar is not None:
-                #     steps_trained_progress_bar.close()
-                #     steps_trained_progress_bar = None
+
+                # Skip past any already trained steps if resuming training
+                if steps_trained_in_current_epoch > 0:
+                    steps_trained_in_current_epoch -= 1
+                    if steps_trained_progress_bar is not None:
+                        steps_trained_progress_bar.update(1)
+                    if steps_trained_in_current_epoch == 0:
+                        self._load_rng_state(resume_from_checkpoint)
+                    continue
+                elif steps_trained_progress_bar is not None:
+                    steps_trained_progress_bar.close()
+                    steps_trained_progress_bar = None
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(
                         args, self.state, self.control
                     )
@@ -509,17 +547,39 @@ def training_step(

         return loss.detach()

+    def _load_from_checkpoint(self, resume_from_checkpoint):
+        log_dist(f"Loading model from {resume_from_checkpoint}.")
+
+        if len(self.args.model_wrappers) > 0:
+            self.model.from_parallelized(resume_from_checkpoint)
+            # from_oslo_parallelized(self.model, resume_from_checkpoint)
+            # self.model.from_oslo_parallelized(resume_from_checkpoint)
+            # self.model.from_parallelized(resume_from_checkpoint)
+        else:
+            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+                raise ValueError(
+                    f"Can't find a valid checkpoint at {resume_from_checkpoint}"
+                )
+            state_dict = torch.load(
+                os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu"
+            )
+            self._load_state_dict_in_model(state_dict)
+            del state_dict
+
     def _load_best_model(self):
         log_dist(
             f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
         )
-
-        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        best_model_path = self.state.best_model_checkpoint
         if os.path.exists(best_model_path):
             # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = torch.load(best_model_path, map_location="cpu")
-            # If the model is on the GPU, it still works!
-            self._load_state_dict_in_model(state_dict)
+            if len(self.args.model_wrappers) > 0:
+                # from_oslo_parallelized(self.model, best_model_path)
+                self.model.from_parallelized(best_model_path)
+            else:
+                state_dict = torch.load(best_model_path, map_location="cpu")
+                # If the model is on the GPU, it still works!
+                self._load_state_dict_in_model(state_dict)
         else:
             log_dist(
                 f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
             )

@@ -527,8 +587,8 @@ def _load_best_model(self):
     def _load_state_dict_in_model(self, state_dict):
-        load_result = self.model.load_state_dict(state_dict, strict=False)
+        load_result = self.model.load_state_dict(state_dict, strict=False)
         if len(load_result.missing_keys) != 0:
             if self.model._keys_to_ignore_on_save is not None and set(
                 load_result.missing_keys
             )
@@ -575,12 +635,48 @@ def _maybe_log_save_evaluate(self, tr_loss, model):
             metrics = self.evaluate()

         if self.control.should_save:
-            self._save_checkpoint(model, metrics=metrics)
+            self._save_checkpoint(metrics=metrics)
             self.control = self.callback_handler.on_save(
                 self.args, self.state, self.control
             )

-    def _save_checkpoint(self, model, metrics=None):
+    def _load_rng_state(self, checkpoint):
+        if checkpoint is None:
+            return
+        local_rank = self.parallel_context.get_local_rank(ParallelMode.GLOBAL)
+        if local_rank != -1:
+            rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
+            if not os.path.isfile(os.path.join(checkpoint, rng_file)):
+                log_dist(
+                    f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                log_dist(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+        checkpoint_rng_state = torch.load(rng_file)
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+        if torch.cuda.is_available():
+            if local_rank != -1:
+                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+            else:
+                try:
+                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+                except Exception as e:
+                    log_dist(
+                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
+
+    def _save_checkpoint(self, metrics=None):
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -660,6 +756,31 @@ def _save_checkpoint(self, model, metrics=None):
         # if self.should_save:
         #     self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

+    def _load_optimizer_and_scheduler(self, checkpoint):
+        """If optimizer and scheduler states exist, load them."""
+        if checkpoint is None:
+            return
+        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
+            os.path.join(checkpoint, SCHEDULER_NAME)
+        ):
+            self.optimizer.load_state_dict(
+                torch.load(
+                    os.path.join(checkpoint, OPTIMIZER_NAME),
+                    map_location=self.args.device,
+                )
+            )
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                self.lr_scheduler.load_state_dict(
+                    torch.load(os.path.join(checkpoint, SCHEDULER_NAME))
+                )
+            reissue_pt_warnings(caught_warnings)
+        if self.do_grad_scaling and os.path.isfile(
+            os.path.join(checkpoint, SCALER_NAME)
+        ):
+            self.scaler.load_state_dict(
+                torch.load(os.path.join(checkpoint, SCALER_NAME))
+            )
+
     def save_model(
         self,
         output_dir: Optional[str] = None,
@@ -676,11 +797,9 @@ def save_model(
             output_dir = self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         log_dist(f"Saving model checkpoint to {output_dir}")
-        log_dist(type(self.model))
-        log_dist(self.args.model_wrappers)

         if len(self.args.model_wrappers) > 0:
-            save_pretrained_oslo(self.model, output_dir, state_dict=state_dict)
+            self.model.save_pretrained(output_dir, state_dict=state_dict)
         else:
             if not isinstance(self.model, PreTrainedModel):
                 if isinstance(unwrap_model(self.model), PreTrainedModel):
@@ -702,9 +821,7 @@ def save_model(
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)

-        # TODO error => TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroupNCCL' object
-        # # Good practice: save your training arguments together with the trained model
-        # torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        self.args.save_args(output_dir)

     def log(self, logs: Dict[str, float]) -> None:
         """
@@ -731,7 +848,12 @@ def evaluate(
         # ignore_keys: Optional[List[str]] = None,
         # metric_key_prefix: str = 'eval',
     ) -> Dict[str, float]:
+        """
+        Run evaluation and return metrics.
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+ """ eval_dataloader = self.get_eval_dataloader(eval_dataset) start_time = time.time() output = self.evaluation_loop( @@ -739,10 +861,11 @@ def evaluate( description="Evaluation", # No point gathering the predictions if there are no metrics, otherwise we defer to # self.args.prediction_loss_only - # prediction_loss_only=True if self.compute_metrics is None else None, + prediction_loss_only=True if self.compute_metrics is None else None, # ignore_keys=ignore_keys, # metric_key_prefix=metric_key_prefix, ) + # log_dist(output.metrics) total_batch_size = ( self.args.eval_batch_size @@ -768,24 +891,24 @@ def evaluation_loop( self, dataloader: DataLoader, description: str, - # prediction_loss_only: Optional[bool] = None, + prediction_loss_only: Optional[bool] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> EvalLoopOutput: args = self.args - - # prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else False + prediction_loss_only = ( + prediction_loss_only if prediction_loss_only is not None else False + ) model = self._wrap_model(self.model_wrappers, training=False) batch_size = self.args.eval_batch_size - log_dist(f"***** Running {description} *****", rank=-1) + log_dist(f"***** Running {description} *****") if has_length(dataloader): - log_dist(f" Num examples = {self.num_examples(dataloader)}", rank=-1) + log_dist(f" Num examples = {self.num_examples(dataloader)}") else: - log_dist(" Num examples: Unknown", rank=-1) - - log_dist(f" Batch size = {batch_size}", rank=-1) + log_dist(" Num examples: Unknown") + log_dist(f" Batch size = {batch_size}") model.eval() @@ -824,7 +947,7 @@ def evaluation_loop( loss, logits, labels = self.prediction_step( model, inputs, - # prediction_loss_only, + prediction_loss_only, ignore_keys=ignore_keys, ) # inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None @@ -966,20 +1089,19 @@ def evaluation_loop( if all_inputs is not None: all_inputs = nested_truncate(all_inputs, num_samples) - # # Metrics! TODO - # if self.compute_metrics is not None and all_preds is not None and all_labels is not None: - # if args.include_inputs_for_metrics: - # metrics = self.compute_metrics( - # EvalPrediction(predictions=all_preds, - # label_ids=all_labels, - # inputs=all_inputs)) - # else: - # metrics = self.compute_metrics( - # EvalPrediction(predictions=all_preds, label_ids=all_labels)) - # else: - # metrics = {} - - metrics = {} + # Metrics! 
+        if (
+            self.compute_metrics is not None
+            and all_preds is not None
+            and all_labels is not None
+        ):
+            metrics = self.compute_metrics(
+                EvalPrediction(
+                    predictions=all_preds, label_ids=all_labels, inputs=all_inputs
+                )
+            )
+        else:
+            metrics = {}

         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
         metrics = denumpify_detensorize(metrics)
@@ -1054,7 +1176,7 @@ def prediction_step(
         self,
         model: nn.Module,
         inputs: Dict[str, Union[torch.Tensor, Any]],
-        # prediction_loss_only: bool,
+        prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
         has_labels = all(inputs.get(k) is not None for k in self.label_names)
@@ -1103,8 +1225,8 @@ def prediction_step(
             # if self.args.past_index >= 0:
             #     self._past = outputs[self.args.past_index - 1]

-            # if prediction_loss_only:
-            #     return (loss, None, None)
+            if prediction_loss_only:
+                return (loss, None, None)

             logits = nested_detach(logits)
             if len(logits) == 1:
@@ -1121,7 +1243,7 @@ def _wrap_model(self, model_wrappers: List, training: bool = True):
         model = self.model

         # Distributed training (should be after apex fp16 initialization)
-        if self.parallel_context is not None:
+        if len(self.model_wrappers) > 0:
             for wrapper in model_wrappers:
                 log_dist(f"Model wrapping with wrapper: {wrapper}")
@@ -1329,7 +1451,7 @@ def create_scheduler(
         Args:
             num_training_steps (int): The number of training steps to do.
         """
-        if self.parallel_context or self.lr_scheduler is None:
+        if self.lr_scheduler is None:
             from transformers import get_scheduler

             self.lr_scheduler = get_scheduler(
@@ -1350,10 +1472,6 @@ def compute_loss(self, model, inputs, return_outputs=False):
         else:
             labels = None
         outputs = model(**inputs)
-        # # TODO: Save past state if it exists
-        # # HF-TODO: this needs to be fixed and made cleaner later.
-        # if self.args.past_index >= 0:
-        #     self._past = outputs[self.args.past_index]

         if labels is not None:
             loss = self.label_smoother(outputs, labels)
@@ -1417,9 +1535,6 @@ def _prepare_inputs(
                 "The batch received was empty, your model won't be able to train on it. Double-check that your "
                 "training dataset contains keys expected by the model."
             )
-        # TODO mems
-        # if self.args.past_index >= 0 and self._past is not None:
-        #     inputs["mems"] = self._past

         return inputs

diff --git a/oslo/transformers/training_args.py b/oslo/transformers/training_args.py
index e348bc60..862c1d32 100644
--- a/oslo/transformers/training_args.py
+++ b/oslo/transformers/training_args.py
@@ -2,6 +2,7 @@
 import os
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Union
+import copy

 import torch
 from transformers.trainer_utils import SchedulerType, IntervalStrategy
@@ -13,6 +14,8 @@
     init_oslo_features,
 )

+TRAINING_ARGS_NAME = "training_args.bin"
+

 @dataclass
 class TrainingArguments:
@@ -103,6 +106,10 @@ class TrainingArguments:
             - `True` if `metric_for_best_model` is set to a value that isn't `"loss"` or `"eval_loss"`.
             - `False` if `metric_for_best_model` is not set, or set to `"loss"` or `"eval_loss"`.

+        ignore_data_skip (`bool`, *optional*, defaults to `False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
+            can take a long time) but will not yield the same results as the interrupted training would have.
         label_smoothing_factor (`float`, *optional*, defaults to 0.0):
             The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
             labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -111,7 +118,6 @@ class TrainingArguments:
             The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
         gradient_checkpointing (`bool`, *optional*, defaults to `False`):
             If True, use gradient checkpointing to save memory at the expense of slower backward pass.
-            # TODO
         report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
             The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
             `"comet_ml"`, `"mlflow"`, `"tensorboard"` and `"wandb"`. Use `"all"` to report to all integrations
@@ -119,7 +125,6 @@ class TrainingArguments:
         save_on_each_node (`bool`, *optional*, defaults to `False`):
             When doing multi-node distributed training, whether to save models and checkpoints on each node, or
             only on the main one.
-
             This should not be activated when the different nodes use the same storage as the files will be saved
             with the same names for each node.
         bf16 (`bool`, *optional*, defaults to `False`):
@@ -147,7 +152,7 @@ class TrainingArguments:
         },
     )
     evaluation_strategy: IntervalStrategy = field(
-        default="no",
+        default="steps",
         metadata={"help": "The evaluation strategy to use."},
     )
     per_device_train_batch_size: int = field(
@@ -257,6 +262,12 @@ class TrainingArguments:
             "help": "Whether the `metric_for_best_model` should be maximized or not."
         },
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )
     label_smoothing_factor: float = field(
         default=0.0,
         metadata={
@@ -322,20 +333,11 @@ def __post_init__(self):
                 "eval_loss",
             ]

-        self.oslo_config, self.parallel_context, self.model_wrappers = None, None, None
-
-        if self.oslo_config_path_or_dict:
-
-            # will be used later by the Trainer
-            self.oslo_config = OsloTrainerConfig(self.oslo_config_path_or_dict)
-            # logging.info(f"Oslo Config: {self.oslo_config}")
-            self.parallel_context, self.model_wrappers = init_oslo_features(
-                self.oslo_config
-            )
-        else:
-            self.parallel_context, self.model_wrappers = init_oslo_features(
-                OsloTrainerConfig({})
-            )
+        (
+            self.oslo_config,
+            self.parallel_context,
+            self.model_wrappers,
+        ) = self.set_oslo_config()

     def __str__(self):
         self_as_dict = {
@@ -386,6 +388,31 @@ def device(self) -> torch.device:
         """
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")

+    def set_oslo_config(self):
+        if self.oslo_config_path_or_dict:
+            oslo_config = OsloTrainerConfig(self.oslo_config_path_or_dict)
+        else:
+            oslo_config = OsloTrainerConfig({})
+        parallel_context, model_wrappers = init_oslo_features(oslo_config)
+        return oslo_config, parallel_context, model_wrappers
+
+    def save_args(self, path):
+        _tmp_parallel_context = self.parallel_context
+        self.parallel_context = None
+        args = copy.deepcopy(self)
+        torch.save(args, os.path.join(path, TRAINING_ARGS_NAME))
+        self.parallel_context = _tmp_parallel_context
+
+    @classmethod
+    def load_args(cls, path):
+        args = torch.load(os.path.join(path, TRAINING_ARGS_NAME))
+        (
+            args.oslo_config,
+            args.parallel_context,
+            args.model_wrappers,
+        ) = args.set_oslo_config()
+        return args
+

 def get_batch_size(per_device_batch_size, n_gpu) -> int:
     """
diff --git a/tests_deprecated/transformers/trainer/test_trainer_reload.py b/tests_deprecated/transformers/trainer/test_trainer_reload.py
new file mode 100644
index 00000000..05d32057
--- /dev/null
+++ b/tests_deprecated/transformers/trainer/test_trainer_reload.py
@@ -0,0 +1,59 @@
+import logging
+import torch
+import os
+from datasets import load_dataset
+from transformers import BertTokenizer, BertForSequenceClassification
+
+from oslo.transformers.tasks.data_sequence_classification import (
+    ProcessorForSequenceClassification,
+    DataCollatorForSequenceClassification,
+)
+from oslo.transformers.trainer import Trainer
+from oslo.transformers.training_args import TrainingArguments
+
+logging.basicConfig(level=logging.INFO)
+
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+# Create the dataset
+dataset = load_dataset("glue", "cola")
+dataset = dataset.rename_column("sentence", "text")
+dataset = dataset.rename_column("label", "labels")
+
+processor = ProcessorForSequenceClassification(tokenizer, 512)
+if processor._tokenizer.pad_token is None:
+    processor._tokenizer.pad_token = processor._tokenizer.eos_token
+
+processed_dataset = dataset.map(
+    processor, batched=True, remove_columns=dataset["train"].column_names
+)
+processed_dataset.cleanup_cache_files()
+train_dataset = processed_dataset["train"]
+valid_dataset = processed_dataset["validation"]
+
+data_collator = DataCollatorForSequenceClassification(processor)
+
+# Define trainer arguments
+reload_path = "output/checkpoint-500"
+args = TrainingArguments.load_args(reload_path)
+
+# Define trainer
+trainer = Trainer(
+    args=args,
+    model=model,
+    tokenizer=tokenizer,
+    train_dataset=train_dataset,
+    eval_dataset=valid_dataset,
+    data_collator=data_collator,
+    resume_from_checkpoint=reload_path,
+)
+
+# Train
+trainer.train()
+
+# # Save
+# trainer.save_model()
+
+# Eval
+metrics = trainer.evaluate(eval_dataset=valid_dataset)
diff --git a/tests_deprecated/transformers/trainer/test_trainer_tp_1d.py b/tests_deprecated/transformers/trainer/test_trainer_tp_1d.py
index 6825c265..71c591ea 100644
--- a/tests_deprecated/transformers/trainer/test_trainer_tp_1d.py
+++ b/tests_deprecated/transformers/trainer/test_trainer_tp_1d.py
@@ -11,6 +11,30 @@
 from oslo.transformers.trainer import Trainer
 from oslo.transformers.training_args import TrainingArguments

+from sklearn.metrics import (
+    accuracy_score,
+    precision_recall_fscore_support,
+    roc_auc_score,
+)
+
+
+def compute_metrics(pred):
+    labels = pred.label_ids
+    preds = pred.predictions.argmax(-1)
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        labels, preds, average="binary"
+    )
+    acc = accuracy_score(labels, preds)
+    auc = roc_auc_score(labels, preds)
+    return {
+        "accuracy": acc,
+        "f1": f1,
+        "precision": precision,
+        "recall": recall,
+        "auroc": auc,
+    }
+
+
 logging.basicConfig(level=logging.INFO)
 os.environ["WANDB_DISABLED"] = "true"

@@ -69,6 +93,7 @@
     train_dataset=train_dataset,
     eval_dataset=valid_dataset,
     data_collator=data_collator,
+    compute_metrics=compute_metrics,
 )

 trainer.train()