Skip to content

Comments

Avoid accessing .dataset of a DataLoader in Trainer#16451

Merged
sgugger merged 16 commits intohuggingface:mainfrom
sanderland:trainer-better-length-check
Mar 29, 2022
Merged

Avoid accessing .dataset of a DataLoader in Trainer#16451
sgugger merged 16 commits intohuggingface:mainfrom
sanderland:trainer-better-length-check

Conversation

@sanderland
Copy link
Contributor

@sanderland sanderland commented Mar 28, 2022

What does this PR do?

  • Respects get_train_dataloader and such, rather than going back and looking at .train_dataset or requiring attributes in the dataloader to be accessible directly.
    • This allows for overriding it by any object which implements the methods required by a DataLoader (__len__ and __iter__) without additional requirements.
    • The original motivation was to train on a multi-task dataloader which defers to multiple dataloaders.

Before submitting

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@sanderland sanderland marked this pull request as ready for review March 29, 2022 10:17
@sanderland
Copy link
Contributor Author

@sgugger this should be ready for review.
You were right that there were a couple of more places to change, and the logic is quite inconsistent in places. I've tried to be on the defensive side in covering cases:

  • dataloader.dataset can exist or not, and have a length or not
  • dataloader always has a len, but it can raise an exception in fairly common cases

This implementation works for my particular case, giving the same output in training+evaluation as before, but without the really painful workarounds.

I had a look at tests and they look complicated, so I will add some after getting confirmation that this is ok otherwise.

@LysandreJik LysandreJik requested a review from sgugger March 29, 2022 10:49
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your PR! It tries to do too much at the same time however. There is no reason to change the signature of the function get_train_dataloader so it should be left as is IMO. Even if it's a change we would like to implement, it should be done on it own separate PR.

Then there is a lot of code that could be refactored using the has_length function (and improving it a tiny bit).

Lastly, this PR breaks the current logging of number of examples, this should be fixed before we can merge it.

Comment on lines 1220 to 1223
len_dataloader = None
try:
len_dataloader = len(train_dataloader)
except (NameError, TypeError): # Default dataloader calls len(dataset), which may not exist
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a function has_length that would simplify the code greatly here, we can add the NameError inside it.

Copy link
Contributor Author

@sanderland sanderland Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring as suggested, although has_length is a bit of a confusing name for __len__ does not raise an exception"

)

logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code will error here since you're not defining num_examples anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_examples was moved up inside the if statements that deal with the len/steps/size cases

num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = len(self.train_dataset) * args.num_train_epochs
else:
# see __init__. max_steps is set when the dataset has no __len__
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this comment was incorrect, it would still be -1 which causes strange outputs. Have change it to make it explicit that this should be set.

Sander Land added 2 commits March 29, 2022 18:18
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adapting, I added a few comments on the tests.

sanderland and others added 5 commits March 29, 2022 19:19
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@sgugger
Copy link
Collaborator

sgugger commented Mar 29, 2022

Thanks for implementing all the tweaks!

@sgugger sgugger merged commit d7c8ce5 into huggingface:main Mar 29, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants