
Trainer: fix support for non-distributed PyTorch #14971

Merged
merged 4 commits into from Oct 3, 2022
Conversation

adamjstewart (Contributor) commented Oct 2, 2022

What does this PR do?

Before this change, Lightning Trainers did not work with PyTorch unless it was built with distributed support. On macOS, this is rarely the case. Using a trainer would result in the following error:

>       trainer.test(model=model, datamodule=datamodule)

tests/trainers/test_byol.py:63: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:862: in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:650: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:909: in _test_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1166: in _run
    results = self._run_stage()
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1249: in _run_stage
    return self._run_evaluate()
../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1294: in _run_evaluate
    with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(self.accelerator):
../.spack/.spack-env/._view/bye6vsynw4rzmsits3aqqhz6ipopfcbi/lib/python3.9/contextlib.py:119: in __enter__
    return next(self.gen)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

accelerator = <pytorch_lightning.accelerators.cpu.CPUAccelerator object at 0x12ab0fa90>

    @contextmanager
    def _evaluation_context(accelerator: Accelerator) -> Generator:
        # inference mode is not supported with gloo backend (#9431),
        # and HPU & TPU accelerators.
        context_manager_class = (
            torch.inference_mode
>           if not (dist.is_initialized() and dist.get_backend() == "gloo")
            and not isinstance(accelerator, HPUAccelerator)
            and not isinstance(accelerator, TPUAccelerator)
            else torch.no_grad
        )
E       AttributeError: module 'torch.distributed' has no attribute 'is_initialized'

../.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:2799: AttributeError

PyTorch's distributed API behaves differently from its CUDA API: when torch.distributed.is_available() returns False, the rest of the module, including torch.distributed.is_initialized(), does not exist at all, so calling it raises AttributeError rather than returning False. This PR first checks that torch.distributed.is_available() is True before calling torch.distributed.is_initialized().

Does your PR introduce any breaking changes? If yes, please list them.

None

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Oct 2, 2022
@awaelchli awaelchli added this to the pl:1.7.x milestone Oct 2, 2022
@awaelchli awaelchli added bug Something isn't working distributed Generic distributed-related topic labels Oct 2, 2022
@awaelchli (Contributor) left a comment

@adamjstewart Thanks! I checked and it does not look like we have other occurrences of this type. Would you mind adding an entry in the "Fixed" section of the CHANGELOG? Thanks

@awaelchli awaelchli added the community This PR is from the community label Oct 2, 2022
@awaelchli awaelchli self-assigned this Oct 2, 2022
@mergify mergify bot added the ready PRs ready to be merged label Oct 2, 2022
@otaj otaj enabled auto-merge (squash) October 3, 2022 07:54
@otaj otaj merged commit 09a8001 into Lightning-AI:master Oct 3, 2022
@adamjstewart adamjstewart deleted the fixes/no-distributed branch October 3, 2022 13:31
nicolai86 pushed a commit that referenced this pull request Oct 3, 2022
* Trainer: fix non-distributed use
* Update CHANGELOG
nicolai86 pushed a commit that referenced this pull request Oct 13, 2022
* Trainer: fix non-distributed use
* Update CHANGELOG
nicolai86 pushed a commit that referenced this pull request Oct 13, 2022
* Trainer: fix non-distributed use
* Update CHANGELOG
nicolai86 added a commit that referenced this pull request Oct 25, 2022
* use more recent lightning cloud launcher

* allow LightningApp to use custom cloud compute for flows

* feedback from adrian

* adjust other cloud tests

* update

* update

* update commens

* Update src/lightning_app/core/app.py

Co-authored-by: Sherin Thomas <[email protected]>

* Close profiler when `StopIteration` is raised (#14945)

* Find last checkpoints on restart (#14907)


Co-authored-by: Carlos Mocholí <[email protected]>

* Remove unused gcsfs dependency (#14962)

* Update hpu mixed precision link (#14974)

Signed-off-by: Jerome <[email protected]>

* Bump version of fsspec (#14975)

fsspec verbump

* Fix TPU test CI (#14926)

* Fix TPU test CI

* +x first

* Lite first to uncovert errors faster

* Fixes

* One more

* Simplify XLALauncher wrapping to avoid pickle error

* debug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug commit successful. Trying local definitions

* Require tpu for mock test

* ValueError: The number of devices must be either 1 or 8, got 4 instead

* Fix mock test

* Simplify call, rely on defaults

* Skip OSError for now. Maybe upgrading will help

* Simplify launch tests, move some to lite

* Stricter typing

* RuntimeError: Accessing the XLA device before processes have spawned is not allowed.

* Revert "RuntimeError: Accessing the XLA device before processes have spawned is not allowed."

This reverts commit f65107e.

* Alternative boring solution to the reverted commit

* Fix failing test on CUDA machine

* Workarounds

* Try latest mkl

* Revert "Try latest mkl"

This reverts commit d06813a.

* Wrong exception

* xfail

* Mypy

* Comment change

* Spawn launch refactor

* Accept that we cannot lazy init now

* Fix mypy and launch test failures

* The base dockerfile already includes mkl-2022.1.0 - what if we use it?

* try a different mkl version

* Revert mkl version changes

Co-authored-by: awaelchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>

* Trainer: fix support for non-distributed PyTorch (#14971)

* Trainer: fix non-distributed use
* Update CHANGELOG

* fixes typing errors in rich_progress.py (#14963)

* revert default cloud compute rename

* allow LightningApp to use custom cloud compute for flows

* feedback from adrian

* update

* resolve merge with master conflict

* remove preemptible

* update CHANGELOG

* add basic flow cloud compute documentation

* fix docs build

* add missing symlink

* try to fix sphinx

* another attempt for docs

* fix new test

Signed-off-by: Jerome <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Sherin Thomas <[email protected]>
Co-authored-by: Ziyad Sheebaelhamd <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jerome Anand <[email protected]>
Co-authored-by: awaelchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: DP <[email protected]>
carmocca added a commit that referenced this pull request Oct 25, 2022
(same commit message as the Oct 25, 2022 commit above)
Labels: bug, community, distributed, pl, ready
4 participants