Conversation

@sgugger (Collaborator) commented Sep 29, 2020

What does this PR do?

This PR limits access to model.config in Trainer to the minimum, so that Trainer works with regular PyTorch modules (as long as they accept dict inputs and return the loss first, like our models). The most challenging part was the storing/restoring of the total_flos, which I moved to the newly created TrainerState. It should work as before and be saved along with the rest of the training state.
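To make the contract concrete, here is a minimal sketch of a model that satisfies the two requirements named above (dict-style inputs, loss returned first). The class and argument names (`ToyRegressor`, `input_values`) are hypothetical, and a real model would subclass `torch.nn.Module`; this stripped-down version only illustrates the calling convention Trainer relies on.

```python
class ToyRegressor:
    """Hypothetical model following the Trainer calling convention:
    accepts dict-style keyword inputs and returns the loss first.
    (A real model would subclass torch.nn.Module.)"""

    def __init__(self, weight=2.0):
        self.weight = weight

    def forward(self, input_values, labels=None):
        # Predictions: a simple linear map as a stand-in for real computation.
        preds = [self.weight * x for x in input_values]
        if labels is not None:
            # Mean squared error as a stand-in loss.
            loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)
            # Loss comes first in the returned tuple, as Trainer expects.
            return (loss, preds)
        return (preds,)


# A batch is unpacked as keyword arguments, i.e. model(**inputs).
batch = {"input_values": [1.0, 2.0], "labels": [2.0, 4.0]}
loss, preds = ToyRegressor().forward(**batch)
```

Because the batch is passed as `model(**inputs)` and the loss is taken from position 0 of the output, Trainer never needs to consult `model.config` for either step.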

@LysandreJik (Member) left a comment

LGTM!

assert not getattr(
self.model.config, "output_hidden_states", False
), "The prediction loop does not work with `output_hidden_states=True`."

Member

Wouldn’t we want to put these lines inside an if statement? The prediction loop still doesn’t work with these outputs, right?

Collaborator Author

Nope, it does now, since the functions that detach/concat etc. all work on nested lists/tuples of tensors :-)
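The pattern behind those helpers is a recursive walk that preserves the nesting structure while transforming only the leaves. This is a minimal sketch of that idea, not the actual Trainer helpers; it is written generically so that passing `lambda t: t.detach()` as `fn` would mirror a nested detach over tensor outputs.

```python
def nested_apply(fn, obj):
    """Apply fn to every leaf of a nested list/tuple structure,
    preserving the container types along the way."""
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type with transformed children.
        return type(obj)(nested_apply(fn, x) for x in obj)
    # Leaf: apply the transformation (e.g. tensor.detach()).
    return fn(obj)


# Example with plain numbers standing in for tensors:
result = nested_apply(lambda x: x * 2, (1, [2, 3], (4,)))
```

A nested concat works the same way, except the recursion walks two structures in lockstep and concatenates at the leaves, which is why hidden states and attentions (tuples of tensors) pass through the prediction loop unchanged.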

@TevenLeScao (Contributor) left a comment

LGTM! Much cleaner this way, thanks!

@sgugger sgugger merged commit fdccf82 into master Sep 30, 2020
@sgugger sgugger deleted the trainer_dont_assume_config branch September 30, 2020 13:03