
Conversation

@sgugger (Collaborator) commented Oct 5, 2020

What does this PR do?

This PR does two things: it cleans up the files supporting Trainer and the Trainer class itself a bit, and it adds callbacks to Trainer.

Callbacks

This PR introduces a new class called TrainerCallback that can access the current state of the training loop and make some decisions (expressed in the TrainerControl object). This allows us to isolate, in another file, the pieces of code that report logs to the various ML platforms or report progress, and to clean up the code of the main train method of the Trainer.

This way, any new log-reporting platform we want to integrate with, or any new behavior (like early stopping), can be implemented in a callback, while Trainer focuses on the main aspects of actual training, with or without mixed precision, on one or several GPUs/TPUs.
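For instance, a user-defined callback could look roughly like this (a minimal sketch; the event and argument names follow the new trainer_callback module, e.g. on_log and on_evaluate, and the early-stopping threshold is just an illustration):

from transformers import TrainerCallback


class MyCallback(TrainerCallback):
    """Minimal sketch of a callback reacting to training events."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `state` is the current TrainerState, `control` the TrainerControl object
        # that the callback can use to influence the training loop.
        if logs is not None and "loss" in logs:
            print(f"step {state.global_step}: loss = {logs['loss']}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Illustrative early stopping: halt training once the eval loss is low enough.
        if metrics is not None and metrics.get("eval_loss", float("inf")) < 0.1:
            control.should_training_stop = True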

As an example, the integrations with TensorBoard, Weights & Biases and Comet ML are moved to the integrations module as clean callbacks, while the control flow of logs/saves/evaluations as well as progress reporting are moved to the trainer_callback file.

Most of the behavior stays the same, since this PR essentially moves code around, but there are a few API changes:

  • Deprecated the tb_writer argument in Trainer (with full backward compatibility); people should now use the TensorBoardCallback.
  • Added a new callbacks argument to the Trainer init, plus new add_callback, pop_callback and remove_callback methods on Trainer. For all of those, you can pass either a callback instance or a callback class (see the sketch after this list).
  • Cleaned up the progress bars a bit: there is now a single main progress bar over all the training steps, and the evaluation bars disappear once they are done.
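A quick sketch of how the new argument and methods would be used (model and train_dataset are placeholders for whatever you would normally pass, and MyCallback is the hypothetical callback sketched above):

from transformers import Trainer, TrainingArguments
from transformers.integrations import TensorBoardCallback

# `model` and `train_dataset` stand in for any model/dataset you would normally use.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="output"),
    train_dataset=train_dataset,
    callbacks=[MyCallback],  # pass a callback class ...
)
trainer.add_callback(TensorBoardCallback())  # ... or an instance
trainer.remove_callback(TensorBoardCallback)  # removal works with classes too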

Progress bars

Here is the new progress bar behavior in console mode (checked in single- and multi-GPU environments, to make sure only one progress bar is displayed and logs are only printed once):

and in a Jupyter notebook:

General cleanup

Not directly related to this PR, but as part of the general cleanup of Trainer, I moved a few things around: the utils defined at the start of Trainer now live in a new trainer_utils_pt module. This way trainer_utils can hold the general training utils that work for both PyTorch and TensorFlow, while the PyTorch-specific ones live in trainer_utils_pt.

Also in Trainer, the code for logs, saves and evaluation ended up being duplicated between the end of a training step and the end of an epoch, so I put it in its own private method to improve readability.

self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)

def setup_wandb(self):
Collaborator Author (sgugger):
Removing this method is the only real breaking change, but I doubt we have a lot of users that subclassed Trainer to implement a custom setup_wandb. For those users, the new way is to subclass WandbCallback and override the setup method.
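Something along these lines (a sketch; the exact setup signature may differ slightly across versions):

from transformers.integrations import WandbCallback


class MyWandbCallback(WandbCallback):
    def setup(self, args, state, model, **kwargs):
        # Custom W&B initialization goes here, instead of overriding Trainer.setup_wandb.
        super().setup(args, state, model, **kwargs)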

Member:
Okay with me

Contributor:
Two breaking changes, actually: setup_comet() was removed as well. That is okay with me too; I just wanted to point it out in case a set of "breaking changes" is listed in the release notes.

self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
Collaborator Author (sgugger):
Not commenting on all of those, but all references to progress bars disappear, as it's all dealt with in ProgressCallback.


return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

def _save_training(self, model, trial, metrics=None):
Collaborator Author (sgugger):
Renamed the private _save_training to _save_checkpoint, as I think it's a better name.

@LysandreJik (Member) left a comment:
This is very cool, I think it cleans up the Trainer class a lot. LGTM!

As discussed with @lhoestq and @julien-c, it would be nice to add in the documentation that these are Keras-style callbacks. When we hear "callbacks", we might understand function callbacks, which are very different in practice.

EDIT: After discussing it with Sylvain, the point made above doesn't need to be applied, as it isn't as ambiguous as I thought it would be. Citing Sylvain: it's callbacks as in events, similar to what is done in Visual Basic or Delphi: on_click/on_touch, etc.

"""

def __init__(self):
assert _has_comet, "CometCallback requires wandb to be installed. Run `pip install comet-ml`."
Member:
I don't think comet requires wandb to be installed :)

Collaborator Author (sgugger):
Thanks, good catch!

Comment on lines 13 to 15
# See the License for the specific language governing permissions and
# limitations under the License.

Member:
You should add a description of the file here.

Comment on lines 145 to 148
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models (expect input and labels as arguments and return
a tuple with the loss first if labels are provided).
Member:
Imo it would be better to clearly define the API using Python typing, rather than just briefly mentioning what a model should expect and return. Something like:

[...]
You can still use your own models defined as :obj:`torch.nn.Module` as long as they work the same way as the 🤗 Transformers models: they should expect `input` and `labels` and return a tuple with the loss as the first output if the `labels` are provided.

class Model(nn.Module):
    def __call__(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor]:
        # do something with the inputs and labels
        return loss

Collaborator Author (sgugger):
This would fail since it needs to return loss, logits (or just logits, if no labels are provided). Maybe just leave it as "as long as they work the same way as the 🤗 Transformers models"?
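(For reference, a rough sketch of a module compatible with that contract, returning (loss, logits) when labels are provided; the names, shapes and loss choice here are purely illustrative:)

import torch
from torch import nn
from typing import Optional, Tuple


class MyModel(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, ...]:
        logits = self.classifier(inputs)
        if labels is not None:
            # Loss first when labels are provided, as the Trainer expects.
            loss = nn.functional.cross_entropy(logits, labels)
            return (loss, logits)
        return (logits,)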

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Member:
Would need to specify the file's purpose here too

return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

def _save_training(self, model, trial, metrics=None):
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
@patrickvonplaten (Contributor) commented Oct 6, 2020:
(nit) I think it's better to have the if self.control.should_log:, if self.control.should_evaluate: and if self.control.should_save: statements directly in the code, calling the functions self._log, self._evaluate and self._save. I'm not a big fan of "maybe" functions.
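A sketch of the suggested pattern as it would appear inside the training loop (the helper names follow the suggestion above; their exact signatures are hypothetical):

# inside the training loop, after a step or at the end of an epoch
if self.control.should_log:
    self._log(logs)
if self.control.should_evaluate:
    self._evaluate(model, trial)
if self.control.should_save:
    self._save_checkpoint(model, trial)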

self.train_dataloader = None
self.eval_dataloader = None

has_flow_callback = False
Contributor:
Could probably be simplified to if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): logger.warn(...)

@patrickvonplaten (Contributor) left a comment:
Looks super clean to me! I don't know enough about the Trainer for an in-detail review though...

@marrrcin (Contributor) commented Oct 7, 2020

Finally! Can't wait for this PR to be merged.
I've briefly looked at the code and from my understanding it should support this case, but correct me if I'm wrong:

When saving checkpoints (including model weights as well as scheduler and optimizer states), I will be able to attach to this process and store the checkpoint in some external repository (e.g. GCS / a W&B artifact), right?

@sgugger (Collaborator Author) commented Oct 7, 2020

Yes, you will be able to hook custom behavior into checkpoint saving with the on_save event.
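A sketch of what that could look like (upload_to_external_storage is a hypothetical helper, and the folder name assumes the usual "checkpoint-<step>" naming convention):

import os
from transformers import TrainerCallback


class UploadCheckpointCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # Folder the Trainer just wrote, e.g. <output_dir>/checkpoint-<global_step>.
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # Hypothetical helper that copies the folder to GCS or logs it as a W&B artifact.
        upload_to_external_storage(checkpoint_dir)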

@sgugger merged commit 08ba4b4 into master on Oct 7, 2020
@sgugger deleted the trainer_callbacks branch on October 7, 2020 at 14:50