-
Notifications
You must be signed in to change notification settings - Fork 31.9k
Floating-point operations logging in trainer #6768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9ee591e
a49f2aa
b50d3e1
6818ed2
5324678
2636bb8
04e471b
f78de89
9e7c05a
8def613
52635d6
7b8c0ce
245df7c
70f919f
349e916
9cc578d
03fe015
fa43ae1
e7a249f
45f5fcb
ab49c08
69d2b1e
d796eef
c175142
1773dd6
4610852
fae5254
6f1b48c
304ebe8
8ec3ea6
4becfac
1aaaa19
eb9d328
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import inspect | ||
| import json | ||
| import math | ||
| import os | ||
| import re | ||
|
|
@@ -40,6 +41,8 @@ | |
| TrainOutput, | ||
| default_compute_objective, | ||
| default_hp_space, | ||
| distributed_broadcast_scalars, | ||
| distributed_concat, | ||
| set_seed, | ||
| ) | ||
| from .training_args import TrainingArguments | ||
|
|
@@ -144,7 +147,7 @@ def __iter__(self): | |
| indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] | ||
| assert ( | ||
| len(indices) == self.num_samples | ||
| ), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched" | ||
| ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" | ||
|
|
||
| return iter(indices) | ||
|
|
||
|
|
@@ -239,6 +242,7 @@ def __init__( | |
| "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | ||
| ) | ||
| self.tb_writer = tb_writer | ||
| self.log_history = [] | ||
| if "prediction_loss_only" in kwargs: | ||
| warnings.warn( | ||
| "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.", | ||
|
|
@@ -292,6 +296,7 @@ def __init__( | |
|
|
||
| self.global_step = None | ||
| self.epoch = None | ||
| self.total_flos = None | ||
| if self.args.fp16 and _use_native_amp: | ||
| self.scaler = torch.cuda.amp.GradScaler() | ||
| self.hp_search_backend = None | ||
|
|
@@ -468,7 +473,11 @@ def setup_wandb(self): | |
| logger.info( | ||
| 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' | ||
| ) | ||
| combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} | ||
| try: | ||
| combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} | ||
| except AttributeError: | ||
| # in case the model has no config | ||
| combined_dict = {**self.args.to_sanitized_dict()} | ||
|
Comment on lines
+479
to
+480
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an example of a model without a configuration?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, it's something @sgugger mentioned as well - when writing for |
||
| wandb.init( | ||
| project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name | ||
| ) | ||
|
|
@@ -638,13 +647,16 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D | |
|
|
||
| self.global_step = 0 | ||
| self.epoch = 0 | ||
| self.total_flos = 0 | ||
| epochs_trained = 0 | ||
| steps_trained_in_current_epoch = 0 | ||
| # Check if continuing training from a checkpoint | ||
| if model_path is not None: | ||
| # set global_step to global_step of last saved checkpoint from model path | ||
| try: | ||
| self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) | ||
| self.total_flos = getattr(model.config, "total_flos", 0) | ||
|
|
||
|
Comment on lines
+658
to
+659
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't this fail if the model didn't have a config?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the dummy test model doesn't go through it since it doesn't have a method to calculate flos so I didn't catch it! See above, I think we might have to decide whether we want to assume it has a config or not |
||
| epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) | ||
| steps_trained_in_current_epoch = self.global_step % ( | ||
| len(train_dataloader) // self.args.gradient_accumulation_steps | ||
|
|
@@ -653,9 +665,11 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D | |
| logger.info(" Continuing training from checkpoint, will skip to saved global_step") | ||
| logger.info(" Continuing training from epoch %d", epochs_trained) | ||
| logger.info(" Continuing training from global step %d", self.global_step) | ||
| logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos) | ||
| logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) | ||
| except ValueError: | ||
| self.global_step = 0 | ||
| self.total_flos = 0 | ||
| logger.info(" Starting fine-tuning.") | ||
|
|
||
| tr_loss = 0.0 | ||
|
|
@@ -689,6 +703,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D | |
| continue | ||
|
|
||
| tr_loss += self.training_step(model, inputs) | ||
| self.total_flos += self.floating_point_ops(inputs) | ||
|
|
||
| if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( | ||
| # last step in epoch but step is always smaller than gradient_accumulation_steps | ||
|
|
@@ -758,7 +773,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D | |
| self.save_model(output_dir) | ||
|
|
||
| if self.is_world_process_zero(): | ||
| self._rotate_checkpoints() | ||
| self._rotate_checkpoints(use_mtime=True) | ||
|
|
||
| if is_torch_tpu_available(): | ||
| xm.rendezvous("saving_optimizer_states") | ||
|
|
@@ -927,6 +942,13 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: | |
|
|
||
| if self.epoch is not None: | ||
| logs["epoch"] = self.epoch | ||
| if self.total_flos is not None: | ||
| if self.args.local_rank != -1: | ||
| total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() | ||
| else: | ||
| total_flos = self.total_flos | ||
| if total_flos > 0: | ||
| logs["total_flos"] = self.total_flos | ||
| if self.global_step is None: | ||
| # when logging evaluation metrics without training | ||
| self.global_step = 0 | ||
|
|
@@ -954,6 +976,8 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: | |
| if experiment is not None: | ||
| experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers") | ||
| output = {**logs, **{"step": self.global_step}} | ||
| if self.is_world_process_zero(): | ||
| self.log_history.append(output) | ||
| if iterator is not None: | ||
| iterator.write(output) | ||
| else: | ||
|
|
@@ -1092,13 +1116,17 @@ def _save_tpu(self, output_dir: Optional[str] = None): | |
| if xm.is_master_ordinal(): | ||
| os.makedirs(output_dir, exist_ok=True) | ||
| torch.save(self.args, os.path.join(output_dir, "training_args.bin")) | ||
| json.dump( | ||
| self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False | ||
| ) | ||
|
|
||
| # Save a trained model and configuration using `save_pretrained()`. | ||
| # They can then be reloaded using `from_pretrained()` | ||
| if not isinstance(self.model, PreTrainedModel): | ||
| raise ValueError("Trainer.model appears to not be a PreTrainedModel") | ||
|
|
||
| xm.rendezvous("saving_checkpoint") | ||
| self._store_flos() | ||
| self.model.save_pretrained(output_dir) | ||
| if self.tokenizer is not None: | ||
| self.tokenizer.save_pretrained(output_dir) | ||
|
|
@@ -1111,12 +1139,26 @@ def _save(self, output_dir: Optional[str] = None): | |
| # They can then be reloaded using `from_pretrained()` | ||
| if not isinstance(self.model, PreTrainedModel): | ||
| raise ValueError("Trainer.model appears to not be a PreTrainedModel") | ||
| self._store_flos() | ||
| self.model.save_pretrained(output_dir) | ||
| if self.tokenizer is not None: | ||
| self.tokenizer.save_pretrained(output_dir) | ||
|
|
||
| # Good practice: save your training arguments together with the trained model | ||
| torch.save(self.args, os.path.join(output_dir, "training_args.bin")) | ||
| json.dump( | ||
| self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False | ||
| ) | ||
|
|
||
| def _store_flos(self): | ||
| # Storing the number of floating-point operations that went into the model | ||
| if self.total_flos is not None: | ||
| if self.args.local_rank != -1: | ||
| total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() | ||
| else: | ||
| total_flos = self.total_flos | ||
| if total_flos > 0: | ||
| self.model.config.total_flos = total_flos | ||
|
|
||
| def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: | ||
| ordering_and_checkpoint_path = [] | ||
|
|
@@ -1248,13 +1290,11 @@ def prediction_loop( | |
| self._past = None | ||
|
|
||
| disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm | ||
| samples_count = 0 | ||
| for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm): | ||
| loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) | ||
| batch_size = inputs[list(inputs.keys())[0]].shape[0] | ||
| samples_count += batch_size | ||
| if loss is not None: | ||
| eval_losses.append(loss * batch_size) | ||
| eval_losses.extend([loss] * batch_size) | ||
| if logits is not None: | ||
| preds = logits if preds is None else torch.cat((preds, logits), dim=0) | ||
| if labels is not None: | ||
|
|
@@ -1267,9 +1307,9 @@ def prediction_loop( | |
| if self.args.local_rank != -1: | ||
| # In distributed mode, concatenate all results from all nodes: | ||
| if preds is not None: | ||
| preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) | ||
| preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) | ||
| if label_ids is not None: | ||
| label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) | ||
| label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) | ||
| elif is_torch_tpu_available(): | ||
| # tpu-comment: Get all predictions and labels from all worker shards of eval dataset | ||
| if preds is not None: | ||
|
|
@@ -1288,7 +1328,14 @@ def prediction_loop( | |
| else: | ||
| metrics = {} | ||
| if len(eval_losses) > 0: | ||
| metrics["eval_loss"] = np.sum(eval_losses) / samples_count | ||
| if self.args.local_rank != -1: | ||
| metrics["eval_loss"] = ( | ||
| distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader)) | ||
| .mean() | ||
| .item() | ||
| ) | ||
| else: | ||
| metrics["eval_loss"] = np.mean(eval_losses) | ||
|
|
||
| # Prefix all keys with eval_ | ||
| for key in list(metrics.keys()): | ||
|
|
@@ -1297,18 +1344,6 @@ def prediction_loop( | |
|
|
||
| return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) | ||
|
|
||
| def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor: | ||
| assert self.args.local_rank != -1 | ||
|
|
||
| output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] | ||
| torch.distributed.all_gather(output_tensors, tensor) | ||
|
|
||
| concat = torch.cat(output_tensors, dim=0) | ||
|
|
||
| # truncate the dummy elements added by SequentialDistributedSampler | ||
| output = concat[:num_total_examples] | ||
| return output | ||
|
|
||
| def prediction_step( | ||
| self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool | ||
| ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: | ||
|
|
@@ -1354,3 +1389,32 @@ def prediction_step( | |
| if labels is not None: | ||
| labels = labels.detach() | ||
| return (loss, logits.detach(), labels) | ||
|
|
||
| def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): | ||
| """ | ||
| For models that inherit from :class:`~transformers.PretrainedModel`, uses | ||
| that method to compute the number of floating point operations for every backward + forward pass. If using | ||
| another model, either implement such a method in the model or subclass and override this method. | ||
|
|
||
| Args: | ||
| model (:obj:`nn.Module`): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep, changed it, allows us to save a few lines in the main method too
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can remove the docstring as well |
||
| The model to evaluate. | ||
| inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): | ||
| The inputs and targets of the model. | ||
|
|
||
| Returns: | ||
| :obj:`int`: The number of floating-point operations. | ||
| """ | ||
|
|
||
| if isinstance(self.model, torch.nn.DataParallel) or isinstance( | ||
| self.model, torch.nn.parallel.DistributedDataParallel | ||
| ): | ||
| model = self.model.module | ||
| else: | ||
| model = self.model | ||
|
|
||
| if hasattr(model, "floating_point_ops"): | ||
| return model.floating_point_ops(inputs) | ||
|
|
||
| else: | ||
| return 0 | ||
Uh oh!
There was an error while loading. Please reload this page.