
Add fault tolerance for the StreamingDataset 1/n #19049

Merged
merged 13 commits into master on Nov 22, 2023

Conversation

tchaton
Contributor

@tchaton tchaton commented Nov 22, 2023

What does this PR do?

This PR adds fault-tolerance support for the StreamingDataset. This enables using it seamlessly with fabric.save, as follows:

dataset = StreamingDataset(...)

state = {..., "model": model, "dataset": dataset}

...

fabric.save(state)
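Under the hood, fabric.save works with any object that exposes the state_dict/load_state_dict protocol. Below is a minimal sketch of that protocol with a toy dataset; the attribute names (current_epoch, chunk_index, index) are illustrative stand-ins, not the actual StreamingDataset internals:

```python
from typing import Any, Dict


class ResumableDataset:
    """Toy dataset following the state_dict/load_state_dict protocol."""

    def __init__(self) -> None:
        # Illustrative progress markers, not the real StreamingDataset fields.
        self.current_epoch = 0
        self.chunk_index = 0
        self.index = 0

    def state_dict(self) -> Dict[str, Any]:
        # Captured at checkpoint time and serialized alongside the model.
        return {
            "current_epoch": self.current_epoch,
            "chunk_index": self.chunk_index,
            "index": self.index,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Restored on resume so iteration continues where it left off.
        self.current_epoch = state["current_epoch"]
        self.chunk_index = state["chunk_index"]
        self.index = state["index"]


dataset = ResumableDataset()
dataset.chunk_index = 3
state = dataset.state_dict()

restored = ResumableDataset()
restored.load_state_dict(state)
```

Because the dataset sits in the same state dictionary as the model, a single fabric.save/fabric.load round trip restores both consistently.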

Fixes #<issue_number>

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--19049.org.readthedocs.build/en/19049/

cc @Borda

@github-actions github-actions bot added the data (external) litdata package label Nov 22, 2023
Contributor

github-actions bot commented Nov 22, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 lightning_data: CPU workflow
Check ID Status
data-cpu (macOS-11, lightning, 3.10, 2.1) success
data-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
data-cpu (windows-2022, lightning, 3.10, 2.1) success

These checks are required after the changes to src/lightning/data/streaming/__init__.py, src/lightning/data/streaming/cache.py, src/lightning/data/streaming/constants.py, src/lightning/data/streaming/dataset.py, src/lightning/data/streaming/item_loader.py, src/lightning/data/streaming/shuffle.py, tests/tests_data/streaming/test_data_processor.py, tests/tests_data/streaming/test_dataset.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/data/streaming/__init__.py, src/lightning/data/streaming/cache.py, src/lightning/data/streaming/constants.py, src/lightning/data/streaming/dataset.py, src/lightning/data/streaming/item_loader.py, src/lightning/data/streaming/shuffle.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/data/streaming/__init__.py, src/lightning/data/streaming/cache.py, src/lightning/data/streaming/constants.py, src/lightning/data/streaming/dataset.py, src/lightning/data/streaming/item_loader.py, src/lightning/data/streaming/shuffle.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Nov 22, 2023

codecov bot commented Nov 22, 2023

Codecov Report

Merging #19049 (38c3c63) into master (bc16580) will decrease coverage by 34%.
The diff coverage is 0%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19049      +/-   ##
==========================================
- Coverage      83%      49%     -34%     
==========================================
  Files         443      435       -8     
  Lines       36185    36120      -65     
==========================================
- Hits        30114    17614   -12500     
- Misses       6071    18506   +12435     

@tchaton tchaton changed the title Add fault tolerance StreamingDataset 1/n Add fault tolerance for the StreamingDataset 1/n Nov 22, 2023
@github-actions github-actions bot removed the pl Generic label for PyTorch Lightning package label Nov 22, 2023
Collaborator

@lantiga lantiga left a comment


Looks great! A couple of comments.

src/lightning/data/streaming/dataset.py (outdated; resolved)
tests/tests_data/streaming/test_dataset.py (resolved)
@tchaton tchaton merged commit 1073276 into master Nov 22, 2023
53 checks passed
@tchaton tchaton deleted the resumable_streaming_dataset branch November 22, 2023 17:22
@mergify mergify bot added the ready PRs ready to be merged label Nov 22, 2023
src/lightning/data/streaming/cache.py (resolved)
@@ -102,6 +106,20 @@ def filled(self) -> bool:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done

@property
def checkpoint_dir(self) -> str:
Contributor


It shouldn't be necessary to duplicate the code in both of these.

Contributor Author


Sure.

src/lightning/data/streaming/dataset.py (resolved)
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
self.chunk_index += 1
Contributor


Why +1 the index? We're reloading it in the line above. If the chunk wasn't complete, wouldn't we now miss the remainder?

Contributor Author


  state = self._state_dict[str(self.cache.rank)]

  # re-generate indexes
  interval = self.worker_intervals[self.chunk_index]
  current_indexes = np.arange(interval[0], interval[1])
  current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
  self.current_indexes = current_indexes[state["index"] :]

  # Bump the chunk_index
  self.chunk_index += 1

No, it won't. The chunk_index is bumped only after the current_indexes are re-created.

I will clean this up in another PR.
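The resume path above can be illustrated with a toy version: re-generate the current chunk's indexes, slice off the ones already consumed (saved as state["index"]), and only then bump the chunk index. The helper below uses plain stand-ins for the shuffler and chunk intervals, not the real implementation:

```python
import numpy as np


def resume_chunk(interval, saved_index, shuffler):
    """Re-generate the per-chunk indexes and drop the ones already consumed."""
    current_indexes = np.arange(interval[0], interval[1])
    current_indexes = shuffler(current_indexes)
    # Only the remainder of the partially consumed chunk is kept,
    # so no samples are skipped when resuming mid-chunk.
    return current_indexes[saved_index:]


# Identity "shuffler" keeps the example deterministic.
remaining = resume_chunk((10, 20), saved_index=4, shuffler=lambda a: a)
print(remaining.tolist())  # [14, 15, 16, 17, 18, 19]
```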


# 3. Move the file to avoid corrupted read from the main thread.
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
Contributor


Why is the file timestamped? This will create many files. IMO we should overwrite the file, because only the current state matters. And since every worker saves only to its own dedicated folder, overwriting is safe.
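A single fixed-name checkpoint file per worker could be overwritten atomically by writing to a temporary file first and swapping it in with os.replace, which is atomic on both POSIX and Windows. A sketch of that alternative (the function name and JSON layout are assumptions, not the PR's code):

```python
import json
import os
import tempfile


def write_checkpoint(checkpoint_rank_dir: str, state: dict) -> str:
    """Overwrite a single per-worker checkpoint file atomically."""
    os.makedirs(checkpoint_rank_dir, exist_ok=True)
    final_path = os.path.join(checkpoint_rank_dir, "checkpoint.json")
    # Write to a temp file in the same directory so os.replace stays atomic
    # (a rename across filesystems would not be).
    fd, tmp_path = tempfile.mkstemp(dir=checkpoint_rank_dir, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    # Atomic swap: a concurrent reader sees either the old or the new
    # complete file, never a partially written one.
    os.replace(tmp_path, final_path)
    return final_path


ckpt_dir = os.path.join(tempfile.mkdtemp(), "ckpt_rank_0")
path = write_checkpoint(ckpt_dir, {"chunk_index": 7, "index": 42})
```

This also addresses the "corrupted read from the main thread" concern the diff comment mentions, without accumulating timestamped files.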

return data

def checkpoint(self, chunk_index: int) -> None:
Contributor


This shouldn't be public. The user wouldn't be able to use it effectively, because they can only call it from the main process, and from the main process it never makes sense.

I suggest we 1) make it private and 2) raise an error if it is called in the main process.
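Both suggestions can be combined in a few lines: in a real DataLoader setup the check would use torch.utils.data.get_worker_info(), which returns None in the main process. The sketch below substitutes a torch-free stand-in for that call so it runs anywhere; everything else is illustrative:

```python
def _get_worker_info():
    """Stand-in for torch.utils.data.get_worker_info(): returns None in the
    main process and a worker-info object inside a DataLoader worker."""
    return None  # this example runs in the main process


class _Dataset:
    def _checkpoint(self, chunk_index: int) -> None:
        # Private method + fail fast when misused from the main process.
        if _get_worker_info() is None:
            raise RuntimeError(
                "`_checkpoint` should only be called from a DataLoader worker process."
            )


try:
    _Dataset()._checkpoint(0)
    raised = False
except RuntimeError:
    raised = True
```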

state_dict.update(**state)
node_ranks.append(node_rank)
else:
raise NotImplementedError("The `state_dict` should be called on the main thread.")
Contributor


But they aren't threads; they are proper processes. Also, we should raise immediately at the beginning: this would eliminate the entire if-else block and make the code much more readable.
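The suggested guard-clause shape, in a stripped-down toy (the boolean flag stands in for the real process check and the returned dict is a placeholder):

```python
class Loader:
    def __init__(self, in_worker: bool) -> None:
        self._in_worker = in_worker  # illustrative flag, not the real check

    def state_dict(self) -> dict:
        # Fail fast up front; the happy path then needs no else branch.
        if self._in_worker:
            raise RuntimeError("`state_dict` should be called on the main process.")
        return {"progress": 0}
```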

Comment on lines +321 to +333
# TODO: Move this to fabric.
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(self.distributed_env.world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [_state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict.update(**state)
node_ranks.append(node_rank)
Contributor


We should put this in a function; it's way easier to unit test!
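Extracted into a free function, the node-deduplication logic can be unit-tested without any distributed setup by injecting the broadcast as a parameter. A sketch (names mirror the snippet above, but the signature is hypothetical and a lambda stubs out torch.distributed):

```python
def gather_state_per_node(world_size, num_devices, broadcast, local_state):
    """Collect one state dict per node, skipping ranks on already-seen nodes.

    `broadcast` is injected so unit tests can replace
    torch.distributed.broadcast_object_list with a plain function.
    """
    state_dict = {}
    node_ranks = []
    for index in range(world_size):
        node_rank = index // num_devices
        if node_rank in node_ranks:
            continue  # only one rank per node contributes
        state = broadcast(local_state, src=index)
        state_dict.update(**state)
        node_ranks.append(node_rank)
    return state_dict


# Stub broadcast: each source rank contributes one key.
merged = gather_state_per_node(
    world_size=4,
    num_devices=2,
    broadcast=lambda state, src: {f"rank_{src}": src},
    local_state={},
)
print(merged)  # {'rank_0': 0, 'rank_2': 2}
```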

if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

# 2. Iterate through the workers and read the latest checkpoint
Contributor


This step wouldn't be necessary; see the comment above.

        assert self.random_state
        return self.random_state.permutation(array).tolist()

    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
        return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
Contributor


This is problematic, because the seed will not be unique. For example, 5 + 4 = 4 + 5 = 9. We should at least multiply the current epoch by the number of chunks.
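The collision and the suggested fix can be checked directly. With seed + epoch + chunk_index, the pairs (epoch=5, chunk=4) and (epoch=4, chunk=5) produce the same seed and hence the same shuffle; seed + epoch * num_chunks + chunk_index is injective as long as chunk_index < num_chunks. The helper names below are illustrative, not the PR's code:

```python
import numpy as np


def shuffle_additive(seed, epoch, chunk_index, array):
    # The scheme from the diff: epoch and chunk_index are interchangeable.
    return np.random.RandomState(seed + epoch + chunk_index).permutation(array).tolist()


def shuffle_unique(seed, epoch, chunk_index, num_chunks, array):
    # epoch * num_chunks + chunk_index is unique for chunk_index < num_chunks.
    return np.random.RandomState(seed + epoch * num_chunks + chunk_index).permutation(array).tolist()


arr = np.arange(100)
# Collision: (epoch=5, chunk=4) and (epoch=4, chunk=5) shuffle identically.
assert shuffle_additive(42, 5, 4, arr) == shuffle_additive(42, 4, 5, arr)
# The fix separates them: seeds 42 + 5*10 + 4 = 96 vs 42 + 4*10 + 5 = 87.
assert shuffle_unique(42, 5, 4, 10, arr) != shuffle_unique(42, 4, 5, 10, arr)
```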

Borda pushed a commit that referenced this pull request Dec 19, 2023
lantiga pushed a commit that referenced this pull request Dec 20, 2023
Labels
  • data (external): litdata package
  • ready: PRs ready to be merged

4 participants