Improve manual optimization API #5771

Merged: 355 commits, merged on Feb 16, 2021.
Showing changes from 1 commit; the full list of 355 commits follows.

Commits
79803f6
Fix import issue, attempting to fix tests
Jan 12, 2021
a7c0d8f
Fix initial test
Jan 12, 2021
02df0ad
Reflect hook logic from master, should wrap model after move to device
Jan 14, 2021
d0ebcba
Optional state consolidation, since master has optimizers not wrapped
justusschock Jan 22, 2021
319c3e8
change attribute for instance test
justusschock Jan 22, 2021
a34cd15
reset optimizers
justusschock Jan 22, 2021
c95b06a
legacy
Borda Jan 22, 2021
9ff0c64
imports in accel
Borda Jan 22, 2021
67d4e47
legacy2
Borda Jan 22, 2021
577b00d
trainer imports
Borda Jan 22, 2021
aa4858b
fix import errors after rebase
awaelchli Jan 25, 2021
f81a44f
move hook to new setup location
awaelchli Jan 25, 2021
a285665
provide unwrapping logic
awaelchli Jan 25, 2021
bf78d70
fix trainer callback system
awaelchli Jan 25, 2021
34947cf
added ddp2 implementation
awaelchli Jan 25, 2021
49bec53
fix imports .legacy
Borda Jan 25, 2021
ba1c986
move plugins
Borda Jan 25, 2021
45dfbb7
restore legacy
Borda Jan 25, 2021
9b7326a
drop test.py from root
Borda Jan 25, 2021
96bc05d
add tpu accelerator and plugins
justusschock Jan 26, 2021
c5994e5
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 30, 2021
9e46624
fixes
awaelchli Jan 30, 2021
22d2ae8
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 30, 2021
901d392
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 31, 2021
e174b8d
fix lightning optimizer merge
awaelchli Jan 31, 2021
98660de
reset bugreportmodel
awaelchli Jan 31, 2021
4d95b6c
unwrapping
awaelchli Jan 31, 2021
b69d013
step routing forward
awaelchli Jan 31, 2021
cb6676d
model access
awaelchli Jan 31, 2021
a33d27f
unwrap
awaelchli Jan 31, 2021
f7486e2
opt
awaelchli Jan 31, 2021
117f16d
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Jan 31, 2021
3792b72
integrate distrib_type
awaelchli Jan 31, 2021
ef85b81
sync changes
awaelchli Jan 31, 2021
9d9a940
sync
awaelchli Feb 1, 2021
f017a39
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
a190a56
fixes
awaelchli Feb 1, 2021
73bb607
add forgotten generators
awaelchli Feb 1, 2021
c8c74f3
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
ae71997
add missing logic
awaelchli Feb 1, 2021
d89847b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
0e686c3
update
awaelchli Feb 1, 2021
d6a43ea
import
awaelchli Feb 1, 2021
ceb8f75
missed imports
awaelchli Feb 1, 2021
fbb7c20
import fixes
awaelchli Feb 1, 2021
b610999
isort
awaelchli Feb 1, 2021
9b79924
mv f
awaelchli Feb 1, 2021
9afe54d
changelog
awaelchli Feb 1, 2021
3b63e82
Merge branch 'release/1.2-dev' into ref/update-plugins
awaelchli Feb 1, 2021
ca8cb68
format
awaelchli Feb 1, 2021
0633745
move helper to parallel plugin
awaelchli Feb 1, 2021
a622e0b
d
awaelchli Feb 1, 2021
18c682f
Merge branch 'ref/update-plugins' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
f275803
add world size
awaelchli Feb 1, 2021
4ae008b
clean up
awaelchli Feb 1, 2021
3b3918b
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 1, 2021
d4c6308
duplicate
awaelchli Feb 1, 2021
7eef4a0
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli Feb 2, 2021
9949164
activate ddp_sharded and tpu
awaelchli Feb 2, 2021
6d47357
set nvidia flags
awaelchli Feb 2, 2021
a6864ec
remove unused colab var
awaelchli Feb 2, 2021
b4b9724
use_tpu <-> on_tpu attrs
awaelchli Feb 2, 2021
81001e3
make some ddp_cpu and clusterplugin tests pass
awaelchli Feb 2, 2021
cea000d
Ref/accelerator connector (#5742)
justusschock Feb 2, 2021
933e2a1
plugins
awaelchli Feb 2, 2021
ad451d8
manual optimization
justusschock Feb 2, 2021
a30a3cf
update optimizer routing
justusschock Feb 2, 2021
a05b291
add rank to torchelastic
justusschock Feb 2, 2021
4388e73
fix memory mixed precision
awaelchli Feb 2, 2021
be9d029
setstate on trainer for pickling in ddp spawn
awaelchli Feb 2, 2021
a90a160
add predict method
awaelchli Feb 2, 2021
767bee0
add back commented accelerator code
awaelchli Feb 2, 2021
f771a7f
adapt test for sync_batch_norm to new plugin
awaelchli Feb 3, 2021
1a3b04e
fix deprecated tests
awaelchli Feb 3, 2021
a1f4938
fix ddp cpu choice when no num_processes are given
awaelchli Feb 3, 2021
38bc8b7
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 3, 2021
ce6b6de
yapf format
awaelchli Feb 3, 2021
3b7c20b
skip a memory test that cannot pass anymore
awaelchli Feb 3, 2021
1d26c9b
update on comments
tchaton Feb 3, 2021
f538c75
fix pickle error in spawn plugin
awaelchli Feb 3, 2021
b44d82e
x
awaelchli Feb 3, 2021
3820e77
avoid
awaelchli Feb 3, 2021
08ae327
x
awaelchli Feb 3, 2021
7d0e094
avoid tons of warnings from importing deprecated modules
awaelchli Feb 3, 2021
1028011
fix cyclic import in docs build
awaelchli Feb 3, 2021
11bd0d6
add support for sharded
justusschock Feb 4, 2021
6bf0b60
update typing
justusschock Feb 4, 2021
f94082b
add sharded and sharded_spawn to distributed types
justusschock Feb 4, 2021
7939b99
make unwrap model default
justusschock Feb 4, 2021
9131ffb
refactor LightningShardedDataParallel similar to LightningDistributed…
justusschock Feb 4, 2021
ed7425c
update sharded spawn to reflect changes
justusschock Feb 4, 2021
209a164
update sharded to reflect changes
justusschock Feb 4, 2021
837a070
Merge 1.1.5 changes
awaelchli Feb 4, 2021
136b321
fix merge
awaelchli Feb 4, 2021
ffcb535
fix merge
awaelchli Feb 4, 2021
1edfa73
yapf isort
awaelchli Feb 4, 2021
a689b81
merge 1.1.6
awaelchli Feb 4, 2021
330b14c
fix merge
awaelchli Feb 4, 2021
ef258d5
yapf isort
awaelchli Feb 4, 2021
c85000d
fix indentation in test
awaelchli Feb 4, 2021
5f3a35e
copy over reinit scheduler implementation from dev1.2
awaelchli Feb 4, 2021
fa1c9b7
fix apex tracking calls with dev_debugger
awaelchli Feb 5, 2021
e330a11
reduce diff to dev1.2, clean up
awaelchli Feb 5, 2021
994ac82
fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
awaelchli Feb 5, 2021
1a78601
sort plugin tests legacy/new
awaelchli Feb 6, 2021
4b76448
fix error handling for amp on cpu
awaelchli Feb 6, 2021
bfd54ab
Merge branch 'release/1.2-dev' into patch117
awaelchli Feb 6, 2021
0574d22
fix merge
awaelchli Feb 6, 2021
6ef6637
Merge branch 'patch117' into accelerator-refactor-sharded
awaelchli Feb 6, 2021
9feda39
[Feat] Resolve manual_backward (#5837)
tchaton Feb 6, 2021
7bb9d9f
fix tests/accelerator tests on cpu
awaelchli Feb 6, 2021
13ae1ff
[BugFix] Resolve manual optimization (#5852)
tchaton Feb 6, 2021
fc3b4db
Merge formatting changes from 1.2 branch
awaelchli Feb 6, 2021
b437642
Remove copy trainer parameters to happen earlier within the loop and …
SeanNaren Feb 7, 2021
8c6aa83
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
Feb 7, 2021
beb980a
resovle a bug
Feb 7, 2021
7a0fd27
Accelerator refactor sharded rpc (#5854)
justusschock Feb 7, 2021
0d0ced5
resolve bug
Feb 7, 2021
1f3ab76
fix assert in rpc test
awaelchli Feb 7, 2021
f1b1121
resolve a test
Feb 7, 2021
cd31fa1
fix docs compilation
awaelchli Feb 8, 2021
f48793e
accelerator refactor - fix for sharded parity test (#5866)
awaelchli Feb 8, 2021
81ff6ea
Remove DDP2 as this does not apply
Feb 8, 2021
20deb46
Add missing pre optimizer hook to ensure lambda closure is called
Feb 8, 2021
be4d1a2
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
Feb 8, 2021
0ac5fc4
fix apex docstring
awaelchli Feb 8, 2021
07fdd95
[accelerator][BugFix] Resolve some test for 1 gpu (#5863)
tchaton Feb 8, 2021
384b791
yapf isort
awaelchli Feb 8, 2021
b1a84b8
resolve flake8
tchaton Feb 8, 2021
a157a29
fix apex doctests
awaelchli Feb 8, 2021
08cfc65
fix apex doctests 2
awaelchli Feb 8, 2021
7888bfd
resolve docs
tchaton Feb 8, 2021
b5b4243
update drone
tchaton Feb 8, 2021
93ceb4c
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
tchaton Feb 8, 2021
d001bcf
clean env
Feb 8, 2021
ad47f47
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 8, 2021
60bfb1a
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 8, 2021
0608a41
update
Feb 8, 2021
f0120b5
update
Feb 8, 2021
bf8874e
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 8, 2021
baf7d7f
update
tchaton Feb 8, 2021
9360aad
update
tchaton Feb 8, 2021
b814cdc
merge
justusschock Feb 9, 2021
0d3ea37
Merge branch 'accelerator-refactor-sharded' of github.com:PytorchLigh…
justusschock Feb 9, 2021
f1f90c2
Fix RPC related tests, clean out old API, update for new accelerator …
SeanNaren Feb 9, 2021
6d05881
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
justusschock Feb 10, 2021
d86fdff
Update test_remove_1-4.py
justusschock Feb 10, 2021
5fbc1cf
Expose properties for tpu cores/gpus/num_gpus
Feb 10, 2021
aa9aea0
Add root GPU property
Feb 10, 2021
c35baf1
Move properties to properties.py
Feb 10, 2021
a9c6e21
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 10, 2021
8f3947b
move tests that were previously in drone
awaelchli Feb 10, 2021
50ecc4a
Fix root GPU property (#5908)
SeanNaren Feb 10, 2021
c7d0075
fix best model path transfer when no checkpoint callback available
awaelchli Feb 10, 2021
3f61d15
Merge remote-tracking branch 'original/accelerator-refactor-sharded' …
awaelchli Feb 10, 2021
061ea46
Fix setup hook order [wip] (#5858)
SeanNaren Feb 10, 2021
1fe1f91
rename ddp sequential -> rpc sequential for special test
awaelchli Feb 10, 2021
3683f5a
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 10, 2021
1f01b81
revert
awaelchli Feb 10, 2021
135c236
fix stupid merge problem
awaelchli Feb 10, 2021
222653d
Use property in connector for sampler (#5913)
SeanNaren Feb 10, 2021
f4311cd
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
awaelchli Feb 11, 2021
b210dee
merge the import conflicts
awaelchli Feb 11, 2021
236009e
fix spawning of processes in slurm
awaelchli Feb 11, 2021
aace276
[wip] Fix some bugs for TPU [skip ci] (#5878)
tchaton Feb 11, 2021
68273f5
resolve some tests
Feb 11, 2021
ca77fa4
update
Feb 11, 2021
c35edfd
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
justusschock Feb 11, 2021
8cacef7
fix imports
justusschock Feb 11, 2021
f7bbe48
update
Feb 11, 2021
30d9800
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
25f7f13
resolve flake8
tchaton Feb 11, 2021
fa28c41
update azure pipeline
tchaton Feb 11, 2021
51c27e6
Merge branch 'release/1.2-dev' into accelerator-refactor-sharded
tchaton Feb 11, 2021
b888d68
skip a sharded test on cpu that requires a gpu
awaelchli Feb 11, 2021
01ca4cd
resolve tpus
Feb 11, 2021
181d143
Merge branch 'master' into accelerator-refactor-sharded
justusschock Feb 11, 2021
946a1e9
resolve bug
Feb 11, 2021
2ad1a6e
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
6e0aff0
resolve flake8
tchaton Feb 11, 2021
a931791
update
Feb 11, 2021
319d034
Merge branch 'accelerator-refactor-sharded' of https://github.com/PyT…
Feb 11, 2021
4117bec
updat utils
Feb 11, 2021
8d000f7
Merge branch 'master' into accelerator-refactor-sharded
tchaton Feb 11, 2021
0b1ba67
revert permission change on files
awaelchli Feb 11, 2021
cc385b4
suggestions from carlos
awaelchli Feb 11, 2021
e9eb318
remove unrelated formatting changes
awaelchli Feb 11, 2021
7c08400
remove incomplete comment
awaelchli Feb 11, 2021
7c3d184
Update pytorch_lightning/accelerators/__init__.py
awaelchli Feb 11, 2021
503426e
remove unrelated formatting change
awaelchli Feb 11, 2021
c0fbf7a
add types
awaelchli Feb 11, 2021
23a9a10
warn 1.7 ddp manual backward only if ddp kwarg unset
awaelchli Feb 11, 2021
a70ee4a
yapf + isort
awaelchli Feb 11, 2021
b0621c4
pep8 unused imports
awaelchli Feb 11, 2021
18bfe70
Merge branch 'master' into accelerator-refactor-sharded
awaelchli Feb 11, 2021
7b0515d
fix cyclic import in docs
awaelchli Feb 12, 2021
d966057
Apply suggestions from code review
Borda Feb 12, 2021
f636d9d
typer in accelerator.py
Borda Feb 12, 2021
5579ea7
typo
tchaton Feb 12, 2021
f5df88b
Apply suggestions from code review
Borda Feb 12, 2021
233694e
formatting
Borda Feb 12, 2021
a47644a
update on comments
tchaton Feb 12, 2021
80dacb6
update typo
tchaton Feb 12, 2021
99573eb
Update pytorch_lightning/trainer/properties.py
tchaton Feb 12, 2021
ab859d7
update
tchaton Feb 12, 2021
0a633cb
Merge branch 'accelerator-refactor-sharded' into feat/5769_manual_opt…
tchaton Feb 12, 2021
4fb36da
update on comments
tchaton Feb 12, 2021
a578ac9
Merge branch 'master' into feat/5769_manual_optimization
awaelchli Feb 13, 2021
00055ac
Merge branch 'master' into feat/5769_manual_optimization
tchaton Feb 13, 2021
a9cdc4e
resolve some comments
tchaton Feb 13, 2021
c219416
Merge branch 'feat/5769_manual_optimization' of https://github.com/Py…
tchaton Feb 13, 2021
5760e12
update on comments
tchaton Feb 13, 2021
09d1f24
resolve test
tchaton Feb 13, 2021
ca71e62
add toggle_model
tchaton Feb 13, 2021
9519a31
update
tchaton Feb 13, 2021
68f5082
update on comments
tchaton Feb 13, 2021
d831931
update doc
tchaton Feb 13, 2021
559972f
typo
tchaton Feb 13, 2021
b5a1e55
update
tchaton Feb 13, 2021
00b9b99
typo
tchaton Feb 13, 2021
c2e79f8
remove space
tchaton Feb 13, 2021
79e6e8e
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
9893e4c
update
tchaton Feb 13, 2021
14e5499
Merge branch 'feat/5769_manual_optimization' of https://github.com/Py…
tchaton Feb 13, 2021
d7d7ec9
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
26a592f
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
652164c
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
d0f5875
update on comments
tchaton Feb 13, 2021
f880878
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
2e2aed9
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 13, 2021
e9ca4ab
update on comments
tchaton Feb 13, 2021
f5dfab0
Merge branch 'feat/5769_manual_optimization' of https://github.com/Py…
tchaton Feb 13, 2021
2454723
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 14, 2021
32795e5
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 14, 2021
6a44f22
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
e78efc4
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
bcd0388
update
tchaton Feb 15, 2021
8084243
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
315201a
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
86b8d98
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
9e3c333
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
684098f
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
5d27b18
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
5dd1c9b
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
84ec28a
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 15, 2021
faa96e9
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 16, 2021
a4a0985
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 16, 2021
e4074aa
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 16, 2021
e70fefe
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 16, 2021
869a46d
Merge branch 'master' into feat/5769_manual_optimization
mergify[bot] Feb 16, 2021
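
For context, the commits above converge on the manual optimization API that this PR improves. Below is a minimal sketch of how a LightningModule drives that API (assuming the Lightning 1.2-era hooks referenced in these commits; the toy layer and loss are illustrative only, not code from this PR):

import torch
from pytorch_lightning import LightningModule

class ManualOptimModel(LightningModule):
    """Toy module exercising the manual optimization API."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Opt out of Lightning's automatic optimization loop.
        return False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()         # LightningOptimizer wrapper around the configured optimizer
        opt.zero_grad()
        loss = self.layer(batch).sum()  # placeholder loss
        self.manual_backward(loss)      # used instead of loss.backward() so precision/distributed plugins still hook in
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

The remainder of this page shows the diff introduced by the single merge commit selected in this view (f017a39).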
Merge branch 'release/1.2-dev' into accelerator-refactor-sharted-4
awaelchli committed Feb 1, 2021
commit f017a397a954acaac8edc59f94c13baa2dd5e5e9
77 changes: 34 additions & 43 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -21,16 +21,12 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins import (
ApexMixedPrecisionPlugin,
DataParallelPlugin,
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
DDPSpawnShardedPlugin,
HorovodPlugin,
NativeMixedPrecisionPlugin,
PrecisionPlugin,
@@ -40,10 +36,13 @@
TPUHalfPrecisionPlugin,
TPUSpawnPlugin,
)
from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
device_parser,
DeviceType,
@@ -53,39 +52,28 @@
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
_HOROVOD_AVAILABLE = False
else:
_HOROVOD_AVAILABLE = True


class BackendConnector(object):

def __init__(
self,
num_processes,
tpu_cores,
distributed_backend,
auto_select_gpus,
gpus,
num_nodes,
sync_batchnorm,
benchmark,
replace_sampler_ddp,
deterministic,
precision,
amp_type,
amp_level,
cluster_environment,
self,
num_processes,
tpu_cores,
distributed_backend,
auto_select_gpus,
gpus,
num_nodes,
sync_batchnorm,
benchmark,
replace_sampler_ddp,
deterministic,
precision,
amp_type,
amp_level,
cluster_environment,
):
# initialization
self._device_type = DeviceType.CPU
@@ -102,7 +90,7 @@ def __init__(
self.replace_sampler_ddp = replace_sampler_ddp
self.deterministic = deterministic
self.precision = precision
self.amp_type = None if amp_type is None else amp_type.lower()
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
self.amp_level = amp_level
self.cluster_environment = cluster_environment
self.is_slurm_managing_tasks = False
@@ -203,7 +191,9 @@ def parallel_devices(self):
if self.on_gpu:
devices = [torch.device("cuda", i) for i in self.parallel_device_ids]
elif self.on_tpu:
devices = [xm.xla_device(i) for i in self.parallel_device_ids]
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
devices = [i for i in self.parallel_device_ids]
else:
devices = [torch.device("cpu")] * self.num_processes
return devices
@@ -266,8 +256,8 @@ def select_training_type_plugin(self):
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self.distributed_backend == "ddp_sharded"
use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn"
# use_ddp_sharded = self.distributed_backend == "ddp_sharded"
# use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn"

if self.on_tpu:
ddp_plugin_cls = TPUSpawnPlugin
@@ -277,11 +267,12 @@ def select_training_type_plugin(self):
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False

if use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
elif use_ddp_sharded_spawn:
ddp_plugin_cls = DDPSpawnShardedPlugin
elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
# fixme
# if use_ddp_sharded:
# ddp_plugin_cls = DDPShardedPlugin
# elif use_ddp_sharded_spawn:
# ddp_plugin_cls = DDPSpawnShardedPlugin
if use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
ddp_plugin_cls = DDPPlugin
elif use_ddp_spawn or use_ddp_cpu_spawn:
ddp_plugin_cls = DDPSpawnPlugin
@@ -388,8 +379,8 @@ def set_distributed_mode(self):

# for DDP overwrite nb processes by requested GPUs
if (
self._device_type == DeviceType.GPU
and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
self._device_type == DeviceType.GPU
and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
):
self.num_processes = self.num_gpus

@@ -407,7 +398,7 @@ def set_distributed_mode(self):

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
num_cores = self.tpu_cores if self.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
rank_zero_warn("GPU available but not used. Set the --gpus flag when calling the script.")
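
The amp_type line changed above now only lowercases string input; a standalone illustration of the new normalization behavior (a sketch, not the connector's actual code):

def normalize_amp_type(amp_type):
    # Lowercase only when a string was passed; anything else falls back to None.
    return amp_type.lower() if isinstance(amp_type, str) else None

assert normalize_amp_type("Native") == "native"
assert normalize_amp_type(None) is None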
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/cpu.py
@@ -1,14 +1,15 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CPUAccelerator(Accelerator):

def setup(self, trainer, model):
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
MisconfigurationException("amp + cpu is not supported. Please use a GPU option")

if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")

return super().setup(trainer, model)
return super().setup(trainer, model)
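
As rendered above, the precision guard in CPUAccelerator.setup constructs a MisconfigurationException without raising it; the intended check presumably reads:

if isinstance(self.precision_plugin, MixedPrecisionPlugin):
    raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")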
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -5,6 +5,7 @@


class GPUAccelerator(Accelerator):

def setup(self, trainer, model):
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
@@ -23,4 +24,4 @@ def on_train_start(self):
def on_train_end(self):
# clean up memory
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()
torch.cuda.empty_cache()
8 changes: 5 additions & 3 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,10 +1,12 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.training_type import SingleTPUPlugin, TPUSpawnPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TPUAccelerator(Accelerator):

def setup(self, trainer, model):
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
@@ -14,4 +16,4 @@ def setup(self, trainer, model):

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)
return super().setup(trainer, model)
31 changes: 29 additions & 2 deletions pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,31 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision import *
from pytorch_lightning.plugins.training_type import *
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

__all__ = [
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUHalfPrecisionPlugin",
"TPUSpawnPlugin",
]
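
With the explicit re-exports above, the concrete plugin classes can be imported directly from the package root rather than from their submodules; a quick sanity check (illustrative only, using names from the __all__ list above):

from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin

# NativeMixedPrecisionPlugin is a precision plugin; DDPPlugin is a training-type plugin.
assert issubclass(NativeMixedPrecisionPlugin, PrecisionPlugin)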
19 changes: 2 additions & 17 deletions pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -15,8 +15,8 @@
import os

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_warn


class TorchElasticEnvironment(ClusterEnvironment):
@@ -46,18 +46,3 @@ def world_size(self):

def local_rank(self):
return int(os.environ['LOCAL_RANK'])

def node_rank(self):
# TODO: use GROUP_RANK and provide a default environment class that uses NODE_RANK
# torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
# otherwise use given node rank or default to node rank 0
env_vars = ['NODE_RANK', 'GROUP_RANK']
node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
node_ids = [(k, v) for k, v in node_ids if v is not None]
if len(node_ids) == 0:
return 0
if len(node_ids) > 1:
log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. Using the first one.")
k, rank = node_ids.pop()
rank_zero_info(f"Using environment variable {k} for node rank ({rank}).")
return int(rank)
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -24,8 +24,8 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.data_parallel import unwrap_lightning_module
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -22,8 +22,8 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.data_parallel import unwrap_lightning_module
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
28 changes: 16 additions & 12 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -1,44 +1,48 @@
from typing import Any, Union

import torch
from torch._C import device

from pytorch_lightning.plugins .training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class SingleDevicePlugin(TrainingTypePlugin):
def __init__(self, device):

def __init__(self, device: torch.device) -> bool:
super().__init__()
self.device: torch.device = device

@property
def on_tpu(self):
return self.device.type == 'xla'
def on_tpu(self) -> bool:
return False

@property
def on_gpu(self):
def on_gpu(self) -> bool:
return self.device.type == "cuda" and torch.cuda.is_available()

def reduce(self, output, *args, **kwargs):
def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
return output

@property
def root_device(self):
def root_device(self) -> torch.device:
return self.device
def model_to_device(self):

def model_to_device(self) -> None:
if self.on_gpu:
torch.cuda.set_device(self.root_device)

self._model.to(self.root_device)

def connect(self, model: torch.nn.Module):
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self.model

@property
def is_global_zero(self):
def is_global_zero(self) -> bool:
return True

def barrier(self, *args, **kwargs):
def barrier(self, *args, **kwargs) -> None:
pass

def broadcast(self, obj: object, src: int = 0) -> object:
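
The refactored SingleDevicePlugin shown above (truncated in this condensed view) can be exercised on its own; a rough sketch, assuming the plugin API as it stands at this commit:

import torch
from pytorch_lightning.plugins import SingleDevicePlugin

plugin = SingleDevicePlugin(torch.device("cpu"))
model = torch.nn.Linear(4, 1)
plugin.connect(model)               # stores the model and moves it to the plugin's device
assert plugin.root_device.type == "cpu"
assert plugin.is_global_zero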
You are viewing a condensed version of this merge commit.