@muellerzr
Contributor

What does this PR do?

This PR is an alternative to #2895 that ends up using composition, to be a pinch less "magical". Because it uses composition, if we eventually want any "base iterable" type of loader to be compatible with or use these methods, this logic could scale, as long as the underlying assumption holds that they behave like DataLoaders.

Fixes #2859

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc @BenjaminBossan @byi8220

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@byi8220
Contributor

byi8220 commented Aug 6, 2024

Thanks! I took a quick look and left a few comments.

The only reason I didn't immediately go for pure composition was that I was worried about breaking something if the class inheritance structure changed. I believe accelerate.py has some explicit checks for isinstance(obj, DataLoader), and other repos might rely on this class structure as well.

If that's not an issue then either I could refactor my PR to be closer to yours, or we could iterate on this PR instead. It's up to you.

Also, I did make some unit tests in PR #2895, in case you could reuse them and save you time from writing your own tests.

@BenjaminBossan
Member

I haven't followed the whole history of this feature. @muellerzr, has the issue mentioned by @byi8220 been addressed, namely that this PR would change the inheritance structure of data loaders and could potentially break existing code that checks isinstance? I tried to search for ways to dynamically modify the isinstance result from the instance side of things, but there does not appear to be any suitable solution in Python. We could override __class__, but that seems exceedingly dangerous. From the side of the class that is checked against, we could override __instancecheck__, but IIUC that would not help here.
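
To make the isinstance concern concrete, here is a minimal sketch in plain Python. The classes `Loader`, `Wrapper` and `SpoofingWrapper` are hypothetical stand-ins, not accelerate or torch code. It shows that a wrapper built by composition fails an isinstance check against the wrapped class, and that exposing `__class__` as a property flips the result while `type()` still reports the real class.

```python
class Loader:
    """Hypothetical stand-in for a base loader class."""

    def __init__(self, data):
        self.data = list(data)

    def __iter__(self):
        return iter(self.data)


class Wrapper:
    """Wraps a Loader by composition; it does not inherit from Loader."""

    def __init__(self, inner):
        self._inner = inner

    def __iter__(self):
        return iter(self._inner)


class SpoofingWrapper(Wrapper):
    """Same wrapper, but it reports the wrapped object's class as its own."""

    @property
    def __class__(self):
        # isinstance() falls back to obj.__class__ when type(obj) does not match
        return type(self._inner)


plain = Wrapper(Loader([1, 2, 3]))
spoof = SpoofingWrapper(Loader([1, 2, 3]))

print(isinstance(plain, Loader))  # composition alone does not satisfy the check
print(isinstance(spoof, Loader))  # the __class__ property makes the check pass
print(type(spoof).__name__)       # type() is not fooled by the property
```

The two views of the same object now disagree, which is one reason the `__class__` route feels risky. `__instancecheck__` is looked up on the metaclass of the class being checked against, so it would have to be defined on the DataLoader side rather than on the wrapper.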

@muellerzr
Contributor Author

@byi8220 all of those checks decide whether we should preprocess someone's existing dataloader using prepare(), so it's fine. I've added you as a coauthor here so you can get commit credit. I'll be bringing in your tests today and finishing up the rest of this PR, so we should be fine.

@muellerzr
Contributor Author

If there are particular isinstance() checks you think would fail downstream because of this, let me know. From the trainer integration side it should be fine (the failing error here is due to a small tweak needed since we don't implicitly inherit from DataLoader right now).

@byi8220
Contributor

byi8220 commented Aug 6, 2024

@BenjaminBossan Yes. This PR overall LGTM, and the main issues I thought of were:

  1. The inheritance structure change breaking checks for isinstance as mentioned, although it seems fine.
  2. An off-by-one bug, since dl_shard._dataloader's iterator is one ahead of what dl_shard is yielding.

> We could override __class__ but that seems exceedingly dangerous.

Yeah, that was the original idea in the previous PR, which is pretty hacky.

@muellerzr

> all of those checks are whether we should preprocess someone's existing dataloader using prepare(),

Makes sense. In this case, a DataLoaderShard or DataLoaderAdapter would just pass through prepare()?

> If there's some particular isinstance() checks you think would fail from this downstream, let me know.

There was nothing in particular I could think of, breaking class structure just felt fishy to me.



- class DataLoaderShard(DataLoader, DataLoaderStateMixin):
+ class DataLoaderShard(DataLoaderStateMixin):
Contributor

Are we 100% okay with changing this? Composition is the better way to do this, but the only issues I could see are existing checks for isinstance(dl, DataLoader) breaking.

@andrewkho andrewkho Aug 6, 2024

Would something like this work? StatefulDataLoader is a subclass of torch.utils.data.DataLoader

Suggested change
- class DataLoaderShard(DataLoaderStateMixin):
+ _BaseClass = torch.utils.data.DataLoader
+ if is_torchdata_available():
+     _BaseClass = torchdata.stateful_dataloader.StatefulDataLoader
+ class DataLoaderShard(_BaseClass, DataLoaderStateMixin):

Contributor Author

No, because this assumes users are blanket using the stateful_dataloader if torchdata is available. We do not want this; it must be configurable by the user.

@andrewkho andrewkho Aug 6, 2024

Do users call the DataLoaderShard constructor directly? If it were a factory method, it would be possible to switch out the class based on config, but I'm guessing that's not an option. An alternative would be something like a StatefulDataLoaderShard that users create instead.
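
As a rough sketch of the factory idea, assuming hypothetical stand-ins `PlainLoader`, `StatefulLoader`, `ShardMixin` and `make_shard` (none of these are accelerate or torchdata APIs), the base class could be chosen from a user-supplied flag rather than from what happens to be installed:

```python
class PlainLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, dataset):
        self.dataset = dataset


class StatefulLoader(PlainLoader):
    """Hypothetical stand-in for a stateful subclass of the plain loader."""

    def state_dict(self):
        return {"position": 0}


class ShardMixin:
    """Hypothetical stand-in for shard-specific behaviour."""

    def describe(self):
        return f"shard over {len(self.dataset)} items"


def make_shard(dataset, use_stateful=False):
    # The user-supplied flag, not library availability, selects the base class.
    base = StatefulLoader if use_stateful else PlainLoader
    # Build the concrete shard class on the fly so it really inherits from `base`.
    shard_cls = type(f"{base.__name__}Shard", (ShardMixin, base), {})
    return shard_cls(dataset)


plain_shard = make_shard([1, 2, 3])
stateful_shard = make_shard([1, 2, 3], use_stateful=True)
```

Both objects still pass `isinstance(obj, PlainLoader)`, and only the stateful one exposes `state_dict`. The cost is that users can no longer call a single class constructor directly, which is the constraint raised in this thread.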


  self.set_epoch(self.iteration)
- dataloader_iter = super().__iter__()
+ dataloader_iter = iter(self._dataloader)
Contributor

I think there is an issue that prevents you from simply delegating state_dict and load_state_dict.

The dataloader_iter is actually one ahead of what we are going to yield. dataloader_iter is always pointing to next_batch, not current_batch.

I got around it in my PR by using a _save_state_dict function to hold the previous value of state_dict. I also wrote some unit tests to catch this issue.

I see, this is because there's a pre-fetch to send to device
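
The off-by-one can be reproduced without torch. The sketch below uses two hypothetical classes, `CountingLoader` and `PrefetchingWrapper`, and is not the code from either PR. The wrapper fetches one batch ahead, so delegating `state_dict` directly reports a position one batch too far, while a snapshot taken before the look-ahead stays aligned with what the caller has actually received.

```python
class CountingLoader:
    """Hypothetical stateful loader that tracks how many batches it has produced."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.yielded = 0

    def __iter__(self):
        # Resume from the saved position
        for batch in self.batches[self.yielded:]:
            self.yielded += 1
            yield batch

    def state_dict(self):
        return {"yielded": self.yielded}

    def load_state_dict(self, state):
        self.yielded = state["yielded"]


class PrefetchingWrapper:
    """Yields batch N only after batch N+1 has already been fetched."""

    _DONE = object()

    def __init__(self, loader):
        self._loader = loader
        self._saved_state = loader.state_dict()

    def __iter__(self):
        it = iter(self._loader)
        current = next(it, self._DONE)
        while current is not self._DONE:
            # Snapshot the state that matches `current`, before looking ahead
            snapshot = self._loader.state_dict()
            upcoming = next(it, self._DONE)
            self._saved_state = snapshot
            yield current
            current = upcoming

    def naive_state_dict(self):
        # Plain delegation: reflects the look-ahead, not what was yielded
        return self._loader.state_dict()

    def state_dict(self):
        return self._saved_state


wrapper = PrefetchingWrapper(CountingLoader(["a", "b", "c"]))
batches = iter(wrapper)
first = next(batches)

# Resume a fresh loader from the snapshot taken after one batch was consumed
resumed = CountingLoader(["a", "b", "c"])
resumed.load_state_dict(wrapper.state_dict())
remaining = list(PrefetchingWrapper(resumed))
```

After one batch has been consumed, `naive_state_dict()` already counts two batches, whereas `state_dict()` counts one, so resuming from the snapshot continues with the second batch instead of skipping it.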

Comment on lines 465 to 471
def __getattr__(self, name):
    # Delegate attribute access to the internal instance
    return getattr(self._dataloader, name)

def __len__(self):
    return len(self._dataloader)

Contributor

This is a lot cleaner than my approach, but it still feels a bit magical. Are we sure len is the only builtin that needs to be supported?
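
On the question of which builtins need support: Python looks up special methods on the type, not the instance, when they are invoked implicitly, so `__getattr__` delegation never sees calls such as `len(obj)`, `iter(obj)` or `obj[i]`. A minimal sketch with hypothetical `Inner` and `Delegating` classes (not the PR's code):

```python
class Inner:
    """Hypothetical wrapped object with a few special methods."""

    batch_size = 2

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        return iter(self.items)

    def __getitem__(self, index):
        return self.items[index]


class Delegating:
    """Forwards ordinary attribute access to the wrapped object."""

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails
        return getattr(self._inner, name)


def raises_type_error(func):
    """Return True if calling func() raises TypeError."""
    try:
        func()
    except TypeError:
        return True
    return False


d = Delegating(Inner([10, 20, 30]))

print(d.batch_size)                        # ordinary attribute: delegated
print(d.__len__())                         # explicit lookup: delegated
print(raises_type_error(lambda: len(d)))   # implicit lookup bypasses __getattr__
print(raises_type_error(lambda: iter(d)))
print(raises_type_error(lambda: d[0]))
```

So every special method the wrapper is expected to support implicitly (`__len__`, `__iter__`, `__getitem__`, `__contains__`, and so on) has to be defined explicitly on the wrapper class, as the PR does for `__len__`.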

@muellerzr
Contributor Author

Given the fact that the Trainer requires a raw torch.utils.data.DataLoader to be inherited at some point, this still doesn't work. For now I'm not a fan of any solution with these, and the user should just create their own stateful dataloader.

Going back to the drawing board with the torch team, as no solutions feel right.

@muellerzr muellerzr marked this pull request as draft August 6, 2024 14:13
@byi8220
Contributor

byi8220 commented Aug 6, 2024

> Given the fact that the Trainer requires a raw torch.utils.data.DataLoader to be inherited at some point,

To confirm, this is blocked on a hard inheritance requirement? And all current solutions feel extremely ugly?

> Going back to the drawing board with the torch team, as no solutions feel right.

Hm, might be overkill, but would it be sufficient to ask the torch team to separate DataLoader into an interface and a concrete class, then have the accelerate derivatives implement the interface? Feels a bit bloated, but also like a canonical solution.
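
A rough sketch of what such a split could look like, using `typing.Protocol` from the standard library. `LoaderInterface`, `ConcreteLoader` and `ShardedLoader` are hypothetical names; torch does not currently expose an interface like this.

```python
from typing import Iterator, Protocol, runtime_checkable


@runtime_checkable
class LoaderInterface(Protocol):
    """Hypothetical interface: anything iterable with a length."""

    def __iter__(self) -> Iterator: ...

    def __len__(self) -> int: ...


class ConcreteLoader:
    """Hypothetical concrete loader; satisfies the interface structurally."""

    def __init__(self, batches):
        self._batches = list(batches)

    def __iter__(self):
        return iter(self._batches)

    def __len__(self):
        return len(self._batches)


class ShardedLoader:
    """Composes any LoaderInterface and yields every world_size-th batch."""

    def __init__(self, inner: LoaderInterface, rank: int, world_size: int):
        self._inner = inner
        self._rank = rank
        self._world_size = world_size

    def __iter__(self):
        for index, batch in enumerate(self._inner):
            if index % self._world_size == self._rank:
                yield batch

    def __len__(self):
        # Number of indices in range(len(inner)) congruent to rank
        remaining = len(self._inner) - self._rank
        return max(0, (remaining + self._world_size - 1) // self._world_size)


shard = ShardedLoader(ConcreteLoader(range(5)), rank=1, world_size=2)
```

With this arrangement, downstream code would check `isinstance(obj, LoaderInterface)` instead of checking against the concrete class, so a composed wrapper passes without inheriting from it. Note that a runtime-checkable protocol only verifies that the methods exist, not their signatures or behaviour.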

@muellerzr
Contributor Author

@byi8220 @BenjaminBossan and I were thinking along similar lines.

@muellerzr muellerzr closed this Aug 20, 2024
Development

Successfully merging this pull request may close these issues.

Improve skip_first_batches method to efficiently support IterableDataset and StatefulDataloader

6 participants