Add fault tolerance for the StreamingDataset 1/n #19049
Conversation
⚡ Required checks status: All passing 🟢

Groups summary:
🟢 lightning_data: CPU workflow (required after the changes)
🟢 mypy (required after the changes)
🟢 install (required after the changes)

Thank you for your contribution! 💜
Codecov Report

Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #19049     +/-   ##
==========================================
- Coverage      83%      49%      -34%
==========================================
  Files         443      435        -8
  Lines       36185    36120       -65
==========================================
- Hits        30114    17614    -12500
- Misses       6071    18506    +12435
Looks great! A couple of comments
@@ -102,6 +106,20 @@ def filled(self) -> bool:
        self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
        return self._is_done

    @property
    def checkpoint_dir(self) -> str:
It shouldn't be necessary to duplicate the code in both of these.
Sure.
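A minimal sketch of one way to avoid the duplication, assuming the two properties in question are `checkpoint_dir` and `checkpoint_rank_dir` and that both currently build and create their directory (the `"checkpoints"` subfolder name and the `rank` attribute are assumptions):

```python
import os


class Cache:
    """Minimal stand-in for the real cache object, just to show the shared helper."""

    def __init__(self, cache_dir: str, rank: int) -> None:
        self._cache_dir = cache_dir
        self.rank = rank

    @property
    def checkpoint_dir(self) -> str:
        # Single place that builds (and creates) the checkpoint directory.
        return self._create_dir(os.path.join(self._cache_dir, "checkpoints"))

    @property
    def checkpoint_rank_dir(self) -> str:
        # Reuses the same helper instead of duplicating the makedirs logic.
        return self._create_dir(os.path.join(self.checkpoint_dir, str(self.rank)))

    def _create_dir(self, path: str) -> str:
        os.makedirs(path, exist_ok=True)
        return path
```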
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
self.chunk_index += 1
Why +1 the index? We're reloading it in the line above. If the chunk wasn't complete, we would now miss the remainder?
state = self._state_dict[str(self.cache.rank)]
# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]
# Bump the chunk_index
self.chunk_index += 1
No, it won't. The `chunk_index` is bumped only once the `current_indexes` are re-created. I will clean this up in another PR.
# 3. Move the file to avoid corrupted read from the main thread.
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
Why is the file timestamped? This will create a lot of files. IMO we should overwrite the file, because only the current state matters, and since every worker only saves to its own dedicated folder, that is safe.
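For reference, a minimal sketch of the overwrite-in-place alternative, assuming a per-worker `checkpoint_rank_dir` and a JSON payload as in the diff (the helper name is hypothetical). `os.replace` keeps the swap atomic, which also addresses the "corrupted read from the main thread" concern without timestamping:

```python
import json
import os


def save_checkpoint(checkpoint_rank_dir: str, state: dict) -> None:
    # Write to a temporary file first, then atomically swap it into place.
    tmp_path = os.path.join(checkpoint_rank_dir, "checkpoint.json.tmp")
    final_path = os.path.join(checkpoint_rank_dir, "checkpoint.json")
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    # os.replace is atomic on the same filesystem, so a reader either sees the
    # previous checkpoint or the new one, never a half-written file.
    os.replace(tmp_path, final_path)
```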
        return data

    def checkpoint(self, chunk_index: int) -> None:
This shouldn't be public. The user wouldn't be able to use this effectively, because they can only call it from the main process, and from the main process it never makes sense.
I suggest we 1) make it private and 2) raise an error if it is called from the main process.
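Something along these lines (a sketch, not the actual implementation; `get_worker_info()` returning `None` in the main process is standard PyTorch behavior, and the method body is omitted):

```python
from torch.utils.data import get_worker_info


def _checkpoint(self, chunk_index: int) -> None:
    # Only dataloader worker processes hold the per-worker state to checkpoint.
    if get_worker_info() is None:
        raise RuntimeError(
            "`_checkpoint` should only be called from a dataloader worker process, not the main process."
        )
    ...  # write this worker's progress for `chunk_index` to its checkpoint directory
```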
                state_dict.update(**state)
                node_ranks.append(node_rank)
        else:
            raise NotImplementedError("The `state_dict` should be called on the main thread.")
But they aren't threads, they are proper processes. Also, we should raise immediately at the beginning; this would eliminate the entire if/else block and make the code much more readable.
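In other words, a guard-clause shape roughly like the following (a sketch; `_is_in_worker_process()` is a hypothetical helper standing in for however the worker check is done today, and the gathering logic stays as in the diff):

```python
def state_dict(self) -> dict:
    # Fail fast in worker processes, then the rest of the method needs no if/else.
    if self._is_in_worker_process():  # hypothetical helper; detection mechanism unchanged
        raise RuntimeError("`state_dict` should only be called from the main process.")

    state_dict: dict = {}
    ...  # gather and merge the per-worker checkpoints as in the diff above
    return state_dict
```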
            # TODO: Move this to fabric.
            num_devices = torch.cuda.device_count() or 1
            node_ranks = []
            for index in range(self.distributed_env.world_size):
                node_rank = index // num_devices
                if node_rank in node_ranks:
                    continue
                state = {}
                obj = [_state_dict]
                torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
                state = obj[0]
                state_dict.update(**state)
                node_ranks.append(node_rank)
We should put this in a function; it would be way easier to unit test!
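For example (a sketch only; the function name and argument types are guesses based on the quoted diff):

```python
from typing import Any, Dict

import torch
import torch.distributed


def _collect_distributed_state_dict(local_state_dict: Dict[str, Any], world_size: int, group: Any) -> Dict[str, Any]:
    """Broadcast each node's local state and merge everything into a single dict (one broadcast per node)."""
    state_dict: Dict[str, Any] = {}
    num_devices = torch.cuda.device_count() or 1
    node_ranks = []
    for index in range(world_size):
        node_rank = index // num_devices
        if node_rank in node_ranks:
            continue
        obj = [local_state_dict]
        torch.distributed.broadcast_object_list(obj, index, group=group)
        state_dict.update(**obj[0])
        node_ranks.append(node_rank)
    return state_dict
```

This can then be exercised in a unit test with a small world size and a mocked process group, independently of the dataset itself.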
        if not os.path.exists(self.cache.checkpoint_dir):
            return state_dict

        # 2. Iterate through the workers and read the latest checkpoint
This step wouldn't be necessary, see comment above
        assert self.random_state
        return self.random_state.permutation(array).tolist()

    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
        return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
This is problematic because the seed will not be unique. For example, 5 + 4 = 4 + 5 = 9. We should at least multiply the current epoch by the number of chunks.
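To make the collision concrete, and to show one possible fix (a sketch; `num_chunks` would have to come from the dataset, and the helper name is hypothetical):

```python
import numpy as np

seed = 42

# Today epoch=5/chunk=4 and epoch=4/chunk=5 collide: both yield seed 42 + 9.
assert seed + 5 + 4 == seed + 4 + 5


def chunk_seed(seed: int, current_epoch: int, chunk_index: int, num_chunks: int) -> int:
    # Scaling the epoch by the number of chunks makes each (epoch, chunk_index) pair unique.
    return seed + current_epoch * num_chunks + chunk_index


rng = np.random.RandomState(seed=chunk_seed(seed, current_epoch=5, chunk_index=4, num_chunks=100))
permuted = rng.permutation(10).tolist()
```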
Co-authored-by: thomas <[email protected]> (cherry picked from commit 1073276)
What does this PR do?
This PR adds fault-tolerance support for the StreamingDataset, which enables using it seamlessly with Fabric.save, as sketched below:
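The original usage example appears to have been dropped from the description; the following is a hedged sketch of the intended flow (import paths, the constructor argument, and the checkpoint layout are assumptions):

```python
from lightning.data import StreamingDataset
from lightning.fabric import Fabric
from torch.utils.data import DataLoader

fabric = Fabric()
dataset = StreamingDataset("s3://my-bucket/optimized-dataset")  # hypothetical remote dataset dir
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
    ...
    # Persist the dataset progress together with the rest of the training state,
    # so a restarted run can resume the stream from where it left off.
    fabric.save("checkpoint.ckpt", {"dataset": dataset.state_dict()})
```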
Fixes #<issue_number>
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--19049.org.readthedocs.build/en/19049/
cc @Borda