Trainer callbacks #7596
Conversation
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)

def setup_wandb(self):
Removing this method is the only real breaking change, but I doubt we have a lot of users that subclassed `Trainer` to implement a custom `setup_wandb`. For those users, the new way is to subclass `WandbCallback` and override the `setup` method.
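As an illustration, a minimal sketch of that new pattern (the `WandbCallback` and its `setup` hook are the ones introduced by this PR; the override body is illustrative):

```python
from transformers.integrations import WandbCallback


class MyWandbCallback(WandbCallback):
    def setup(self, *args, **kwargs):
        # Put any custom W&B initialization here (e.g. environment variables,
        # a different project name), then defer to the default setup.
        super().setup(*args, **kwargs)
```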
Okay with me
Two breaking changes, then: `setup_comet()` was removed as well. That is okay with me too; just wanted to point it out in case the release notes list a set of "breaking changes".
self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
Not commenting on all of those, but all references to progress bars disappear since they are now dealt with in `ProgressCallback`.
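For illustration, a hedged sketch of how the progress bar can now be controlled through the callback API rather than flags in the training loop (`ProgressCallback` handles the tqdm bars per this PR; `PrinterCallback` and the exact module path are assumptions based on current transformers; the dummy model only makes the snippet self-contained):

```python
from torch import nn
from transformers import Trainer, TrainingArguments
from transformers.trainer_callback import PrinterCallback, ProgressCallback

# Dummy model and arguments, just to build a Trainer for the example.
model = nn.Linear(4, 2)
args = TrainingArguments(output_dir="out")

trainer = Trainer(model=model, args=args)
trainer.remove_callback(ProgressCallback)  # drop the tqdm-based progress bar
trainer.add_callback(PrinterCallback)      # fall back to plain print logging
```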
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

def _save_training(self, model, trial, metrics=None):
Renamed the private `_save_training` to `_save_checkpoint` as I think it's a better name.
This is very cool, I think it cleans up the Trainer class a lot. LGTM!
As discussed with @lhoestq and @julien-c, it would be nice to add in the documentation that these are Keras-style callbacks. When we hear "callbacks", we might understand function callbacks, which are very different in practice.
EDIT: After discussing it with Sylvain, the point made above doesn't need to be applied, as it isn't as ambiguous as I thought it would be. Citing Sylvain: it's "callback" as in events, similar to what is done in Visual Basic or Delphi: on_click/on_touch, etc.
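To make the event-style reading concrete, here is a minimal sketch of such a callback (the `on_epoch_end` hook and its signature follow the `TrainerCallback` API added in this PR; the body is illustrative):

```python
from transformers import TrainerCallback


class PrintEpochCallback(TrainerCallback):
    """Event-style ("Keras-style") callback: the Trainer fires these hooks
    at fixed points of the training loop, much like on_click/on_touch handlers."""

    def on_epoch_end(self, args, state, control, **kwargs):
        print(f"Finished epoch {state.epoch} at global step {state.global_step}")
```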
src/transformers/integrations.py
Outdated
| """ | ||
|
|
||
| def __init__(self): | ||
| assert _has_comet, "CometCallback requires wandb to be installed. Run `pip install comet-ml`." |
I don't think comet requires wandb to be installed :)
Thanks, good catch!
# See the License for the specific language governing permissions and
# limitations under the License.
You should add the description of the file here
src/transformers/trainer.py
Outdated
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models (expect input and labels as arguments and return
a tuple with the loss first if labels are provided).
Imo it would be better to clearly define the API using Python typing, rather than briefly mentioning what a model should expect and return. Something like:
[...]
You can still use your own models defined as :obj:`torch.nn.Module` as long as they work the same way as the 🤗 Transformers models: they should expect `input` and `labels` and return a tuple with the loss as the first output if the `labels` are provided.
class Model(nn.Module):
    def __call__(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor]:
        # do something with the inputs and labels
        return loss
This would fail since it needs to return loss, logits (or just logits, if no labels are provided). Maybe just leave it as "as long as they work the same way as the 🤗 Transformers models"?
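For reference, a hedged sketch of a custom model that follows that convention (the class name and layer sizes are made up for illustration):

```python
from typing import Optional

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self, hidden_size: int = 10, num_labels: int = 2):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None):
        logits = self.classifier(inputs)
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return loss, logits  # loss first when labels are provided
        return (logits,)
```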
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Would need to specify the file's purpose here too
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

def _save_training(self, model, trial, metrics=None):
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
(nit) I think it's better to have the `if self.control.should_log:`, `if self.control.should_evaluate:` and `if self.control.should_save:` statements directly in the code, with the functions `self._log`, `self._evaluate` and `self._save`. I'm not a big fan of "maybe" functions.
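Spelled out, the suggestion would read roughly as follows inside the log/save/evaluate helper (a sketch only; the exact arguments passed to `self._log`, `self._evaluate` and `self._save` are assumptions):

```python
# Sketch of the suggested control flow inside the Trainer training loop.
metrics = None
if self.control.should_log:
    self._log(tr_loss)
if self.control.should_evaluate:
    metrics = self._evaluate()
if self.control.should_save:
    self._save(model, trial, metrics=metrics)
```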
src/transformers/trainer_callback.py
Outdated
self.train_dataloader = None
self.eval_dataloader = None

has_flow_callback = False
Could probably be simplified to `if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): logger.warn(...)`
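Formatted as a block, the suggested check would look like this (the warning message is a placeholder, not the actual text):

```python
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
    logger.warning(
        "No DefaultFlowCallback found in the callbacks list; the default "
        "logging/saving/evaluation flow of the Trainer will not run."
    )
```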
patrickvonplaten left a comment
Looks super clean to me! I don't know enough about the Trainer for an in-detail review though...
Finally! Can't wait for this PR to be merged. When saving checkpoints (including model weights as well as scheduler and optimizer states), I will be able to attach to this process and store the checkpoint in some external repository (e.g. GCS / a W&B artifact), right?
Yes, you will be able to inject custom behavior to the saved checkpoint with the `on_save` callback event.
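For example, a hedged sketch of such a callback (assuming the `on_save` event on `TrainerCallback`; `upload_directory` is a placeholder for the actual GCS / W&B upload logic, not part of transformers):

```python
import os

from transformers import TrainerCallback


def upload_directory(path: str) -> None:
    """Placeholder: replace with a real upload to GCS or a W&B artifact."""
    print(f"uploading {path} ...")


class UploadCheckpointCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # The Trainer writes each checkpoint to output_dir/checkpoint-<global_step>.
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        upload_directory(checkpoint_dir)
```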
What does this PR do?
This PR does two things: clean up a bit the files supporting `Trainer` and the `Trainer` class, and add callbacks to `Trainer`.

Callbacks
This PR introduces a new class called `TrainerCallback` that can access the current state of the training loop and make some decisions (shown in the `TrainerControl` object). This allows us to isolate the pieces of code that do log-reporting on the various ML platforms or report progress in another file, and clean up the code of the main `train` method of the `Trainer`. This way, any new platform we want to integrate with for log-reporting, or new behavior (like early stopping), can be implemented in a callback while `Trainer` focuses on the main aspects of actual training, with or without mixed precision, on one or several GPUs/TPUs.

As an example, the integrations with TensorBoard, Wandb and Comet ML are moved to the `integrations` module in clean callbacks, while the control flow of logs/saves/evaluations as well as progress reporting are moved to the `trainer_callback` file.

Most of the behavior stays the same as this PR essentially moves code around, but there are a few API changes:
- The `tb_writer` argument in `Trainer` is deprecated (with full backward compatibility); people should now use the `TensorBoardCallback`.
- New `callbacks` argument in the `Trainer` init, and new `add_callback`, `pop_callback` and `remove_callback` methods on the `Trainer`. For all of those, you can either pass an instance of a callback or a callback class (see the sketch below).
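A hedged sketch of the resulting API (the dummy model, arguments and callback are only there to make the snippet self-contained):

```python
from torch import nn
from transformers import Trainer, TrainerCallback, TrainingArguments


class MyCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        print("Training is starting")


trainer = Trainer(
    model=nn.Linear(4, 2),
    args=TrainingArguments(output_dir="out"),
    callbacks=[MyCallback],          # a class or an instance both work
)
trainer.add_callback(MyCallback())   # add another one after init
trainer.pop_callback(MyCallback)     # remove it and get the instance back
```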
Progress bars

Here is the new progress bar behavior in console mode (checked in single and multi GPU envs, to make sure only one progress bar is displayed/logs are only printed once):
and in a Jupyter notebook:
General cleanup
Not directly related to this PR, but related to the general cleanup of `Trainer`, I moved a bit of stuff around: moved the utils at the start of `Trainer` to a new `trainer_utils_pt`. This way `trainer_utils` can be about the general training utils that work on both PyTorch and TensorFlow, and I moved the ones specific to PyTorch to `trainer_utils_pt`.
Trainer, the code for logs, save and evaluation ended being duplicated between the end of a training step and the end of an epoch, so I put it in its private method to improve readability.