Lhotse Shar tutorial notebook #1006
Conversation
Line #3. `shards = cuts_train.to_shar(data_dir, fields={"recording": "wav"}, shard_size=1000)`

May be helpful to explain what other values may be passed to the `fields` dict.

Also, what is the granularity of sharding? Is each recording treated as a unit?
Yes, each cut is treated as a unit, so shard_size=1000 means 1000 cuts per shard. I will add this explanation to the tutorial and also provide a link to the documentation that describes the `fields` dict in more detail (I noticed that Lhotse Shar docs were not linked in readthedocs and I'm also fixing that now).
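For a quick illustration, something along these lines (a sketch -- the manifest path and output directory names are placeholders, and the full list of supported formats is in the Shar docs):

```python
from lhotse import CutSet

# Placeholder manifest path for illustration.
cuts_train = CutSet.from_file("cuts_train.jsonl.gz")

# Each cut is one unit of sharding, so shard_size=1000 means 1000 cuts per shard tarball.
# The `fields` dict selects which data fields to export and in which format,
# e.g. "recording": "wav" stores each cut's audio as WAV inside the shards
# (other formats such as "flac" exist too -- see the docs for the complete list).
shards = cuts_train.to_shar("data_shar", fields={"recording": "wav"}, shard_size=1000)

# `shards` describes the shard files that were written for each exported field.
print(shards)
```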
What do you mean by "keep reshuffling as the full epoch is reached"? Does this mean each worker's copy is shuffled (with a different RNG) at the start of every epoch of training?
Also, what would happen if we do not use the iterable dataset wrapper? Would all the workers generate the same batches?
> What do you mean by "keep reshuffling as the full epoch is reached"? Does this mean each worker's copy is shuffled (with a different RNG) at the start of every epoch of training?
It roughly means the following: given a dataset of three shards [A, B, C], a single node, two dataloader workers W1 and W2, and a global random seed=0, the training dataloading might look like the following (assuming `stateful_shuffle=True`):
Epoch 0:

- W1 uses RNG with seed (global=0 + worker-id=1 + 1000*rank=0) + epoch=0 = 1 and has order: [B, A, C]
- W2 uses RNG with seed (global=0 + worker-id=2 + 1000*rank=0) + epoch=0 = 2 and has order: [C, B, A]

Epoch 1:

- W1 uses RNG with seed (global=0 + worker-id=1 + 1000*rank=0) + epoch=1 = 2 and has order: [C, B, A]
- W2 uses RNG with seed (global=0 + worker-id=2 + 1000*rank=0) + epoch=1 = 3 and has order: [A, B, C]

... and so on.
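As a toy paraphrase of that seed arithmetic (not the actual Lhotse implementation, just the scheme described above):

```python
def shard_shuffle_seed(global_seed: int, worker_id: int, rank: int, epoch: int) -> int:
    # Each (worker, rank, epoch) combination gets its own RNG seed, so every
    # dataloader worker sees a different shard order, and the order changes
    # again when the next epoch starts (stateful_shuffle=True).
    return global_seed + worker_id + 1000 * rank + epoch

assert shard_shuffle_seed(0, worker_id=1, rank=0, epoch=0) == 1  # W1, epoch 0 -> [B, A, C]
assert shard_shuffle_seed(0, worker_id=2, rank=0, epoch=1) == 3  # W2, epoch 1 -> [A, B, C]
```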
Note that since `.repeat()` makes the CutSet infinite, the dataloader will never stop yielding data, so you won't easily know what the current epoch is -- it's best to count steps. If you really need to know the epoch, Shar attaches a custom field `cut.shar_epoch` to each cut that you can read out. You'll also generally observe that each shar_epoch contains world_size * num_workers actual epochs in this setup.
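For example, roughly like this (the directory name and the exact `from_shar` arguments here are a sketch -- check the tutorial/docs for the precise ones):

```python
from lhotse import CutSet

# Assumed input directory; stateful_shuffle reshuffles shards with an
# epoch-dependent seed, and .repeat() makes the CutSet infinite.
cuts = CutSet.from_shar(
    in_dir="data_shar",
    shuffle_shards=True,
    stateful_shuffle=True,
    seed=0,
).repeat()

for step, cut in enumerate(cuts):
    # Count steps for training logic; read cut.shar_epoch only if you need the epoch.
    print(step, cut.shar_epoch)
    if step >= 5:
        break
```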
BTW, after writing this I realized I need to check what kind of IDs PyTorch assigns to workers, so we can avoid seeing too much of the same data order (randomized augmentations probably help with that, though, and it should matter less with large datasets).
> Also, what would happen if we do not use the iterable dataset wrapper? Would all the workers generate the same batches?
Then you'd end up with data I/O happening in the training loop process (since with WebDataset and Shar the I/O happens upon iterating the CutSet) and binary blobs being transferred to the DataLoader worker processes. But the workers wouldn't duplicate the data. It'd be great to emit a warning to the user if that happens, but I don't really have an idea how to detect it.
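For completeness, a rough sketch of the wrapped setup, where the sampler and dataset both live inside the DataLoader workers so the shard I/O happens there (directory name and sampler settings are placeholders; a typical ASR dataset class is assumed):

```python
from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper

# Lazily iterate shards; I/O only happens when the CutSet is iterated.
cuts = CutSet.from_shar(in_dir="data_shar", shuffle_shards=True, stateful_shuffle=True, seed=0).repeat()

dloader = DataLoader(
    # Wrapping both the sampler and the map-style dataset in an iterable dataset
    # moves the CutSet iteration (and hence the shard I/O) into the worker processes.
    IterableDatasetWrapper(
        dataset=K2SpeechRecognitionDataset(),
        sampler=DynamicBucketingSampler(cuts, max_duration=100.0, shuffle=True),
    ),
    batch_size=None,  # batches are formed by the sampler, not the DataLoader
    num_workers=2,
)
# Iterate `dloader` in the training loop; note it never ends because of .repeat().
```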
Great tutorial! I was able to get a nice overview of Lhotse Shar.
Finally found some time to write it down a bit. It doesn't show every possible option, but it should be enough to get started. I think in general this workflow might have been a bit simpler if 3 years ago I had known that Lhotse would go in this direction :) maybe some simplifications (and breaking changes) can be made in the future, but I don't plan them right now.