Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if optimizer supports closure #4981

Merged
merged 35 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
e4fe713
check if optimizer support closure
tchaton Dec 4, 2020
46712a0
cleanup test
tchaton Dec 4, 2020
1b77a87
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 4, 2020
1c8df24
resolve tests
tchaton Dec 4, 2020
b4abb72
Merge branch 'bugfix/resolve_lightning_optimizer' of https://github.c…
tchaton Dec 4, 2020
96adb50
resolve flake
tchaton Dec 4, 2020
7663f5c
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 4, 2020
4cb08ea
update test due to patch limit
tchaton Dec 4, 2020
ecb1a76
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 5, 2020
1dfa83b
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 5, 2020
b5661b4
update
tchaton Dec 6, 2020
842d31e
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 6, 2020
d4a8bb5
update dep
tchaton Dec 6, 2020
cab52c9
Merge branch 'bugfix/resolve_lightning_optimizer' of https://github.c…
tchaton Dec 6, 2020
d1d6ee0
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 7, 2020
94131bd
Update tests/core/test_lightning_optimizer.py
tchaton Dec 9, 2020
33cf178
Update tests/core/test_lightning_optimizer.py
tchaton Dec 9, 2020
8ebe84e
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 9, 2020
842cad6
resolve bug
tchaton Dec 9, 2020
715b820
update test
tchaton Dec 9, 2020
4cd8d74
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 9, 2020
13dfeaf
resolve tests
tchaton Dec 9, 2020
f2d8b72
Update requirements/extra.txt
tchaton Dec 9, 2020
d2a5b44
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 9, 2020
97727ef
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 10, 2020
f1fcd26
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 10, 2020
9d0b4ba
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 11, 2020
59bf595
remove bolts dep
tchaton Dec 11, 2020
80c039a
Merge branch 'bugfix/resolve_lightning_optimizer' of https://github.c…
tchaton Dec 11, 2020
3b341ec
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 11, 2020
76a1022
remove bolts
tchaton Dec 11, 2020
4aace55
Merge branch 'master' into bugfix/resolve_lightning_optimizer
tchaton Dec 11, 2020
5b52e85
add missing bolts dep for tests
tchaton Dec 11, 2020
47c848a
Merge branch 'bugfix/resolve_lightning_optimizer' of https://github.c…
tchaton Dec 11, 2020
83faead
remove need for bolts
tchaton Dec 11, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Optional
from weakref import proxy
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(self,
self._trainer = None
self._optimizer = optimizer
self._accumulate_grad_batches = accumulate_grad_batches
self._automatic_optimization = None
self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
self._optimizer_idx = None

@property
Expand All @@ -73,7 +74,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches):

def _on_trainer_init(self, trainer):
self._trainer = proxy(trainer)
self._automatic_optimization = trainer.train_loop.automatic_optimization
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
Expand Down Expand Up @@ -111,7 +111,11 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n

else:
with trainer.profiler.profile(profiler_name):
optimizer.step(closure=closure, *args, **kwargs)
if self._support_closure:
optimizer.step(closure=closure, *args, **kwargs)
else:
closure()
optimizer.step(*args, **kwargs)

accelerator_backend = trainer.accelerator_backend
if accelerator_backend is not None and accelerator_backend.rpc_enabled:
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool:
OMEGACONF_AVAILABLE = _module_available("omegaconf")
HYDRA_AVAILABLE = _module_available("hydra")
HOROVOD_AVAILABLE = _module_available("horovod.torch")
BOLTS_AVAILABLE = _module_available("pl_bolts")

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
5 changes: 5 additions & 0 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
Expand Down
74 changes: 62 additions & 12 deletions tests/core/test_lightning_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import torch.nn as nn
from torch.optim import Adam, Optimizer

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset


def test_lightning_optimizer(tmpdir):
Expand Down Expand Up @@ -80,8 +82,8 @@ def configure_optimizers(self):
assert trainer.optimizers[0].__repr__() == expected


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@patch("torch.optim.Adam.step", autospec=True)
@patch("torch.optim.SGD.step", autospec=True)
def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir):
"""
Test that the user can use our LightningOptimizer. Not recommended for now.
Expand All @@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss_1 = self.loss(batch, output)
self.manual_backward(loss_1, opt_1)
opt_1.step(idx="1")
opt_1.step()

def closure():
output = self.layer(batch)
loss_2 = self.loss(batch, output)
self.manual_backward(loss_2, opt_2)
opt_2.step(closure=closure, idx="2")
opt_2.step(closure=closure)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
Expand Down Expand Up @@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool:
assert len(mock_adam_step.mock_calls) == 8


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@patch("torch.optim.Adam.step", autospec=True)
@patch("torch.optim.SGD.step", autospec=True)
def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir):
"""
Test that the user can use our LightningOptimizer. Not recommended.
Expand All @@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss_1 = self.loss(batch, output)
self.manual_backward(loss_1, opt_1)
opt_1.step(idx="1")
opt_1.step()

def closure():
output = self.layer(batch)
loss_2 = self.loss(batch, output)
self.manual_backward(loss_2, opt_2)
opt_2.step(closure=closure, idx="2")
opt_2.step(closure=closure)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
Expand Down Expand Up @@ -195,9 +197,8 @@ def test_state(tmpdir):
assert isinstance(lightning_optimizer, Adam)
assert isinstance(lightning_optimizer, Optimizer)
lightning_dict = {}
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
"_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
"_accumulate_grad_batches"]
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
"_trainer"]
for k, v in lightning_optimizer.__dict__.items():
if k not in special_attrs:
lightning_dict[k] = v
Expand All @@ -206,6 +207,55 @@ def test_state(tmpdir):
assert optimizer.state == lightning_optimizer.state


def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
class OptimizerWrapper(object):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, optimizer):
self.optim = optimizer
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__
self.__getstate__ = self.optim.__getstate__
self.__repr__ = self.optim.__repr__

@property
def __class__(self):
return Optimizer

@property
def state(self):
return self.optim.state

@property
def param_groups(self):
return self.optim.param_groups

@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value

def step(self):
# wrongly defined step. Should contain closure
self.optim.step(closure=None)

class TestLightningOptimizerModel(BoringModel):

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
optimizer = OptimizerWrapper(optimizer)
return [optimizer]

model = TestLightningOptimizerModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model)


def test_lightning_optimizer_automatic_optimization(tmpdir):
"""
Test lightning optimize works with make_optimizer_step in automatic_optimization
Expand Down
11 changes: 5 additions & 6 deletions tests/trainer/optimization/test_manual_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def optimizer_closure():
retain_graph = num_backward != backward_idx # noqa E225
self.manual_backward(loss_1, opt, retain_graph=retain_graph)

opt.step(1, closure=optimizer_closure, something="new")
opt.step(closure=optimizer_closure)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
Expand Down Expand Up @@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
expected_calls = [call() for s in range(2)]
step_mock.assert_has_calls(expected_calls)


Expand Down Expand Up @@ -902,7 +902,7 @@ def dis_closure():
if batch_idx % 4 == 0 :
# Note: Set make_optimizer_step to True or it will use by default
# Trainer(accumulate_grad_batches=x)
opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
opt_dis.step(closure=dis_closure, make_optimizer_step=True)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
Expand Down Expand Up @@ -933,10 +933,9 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
expected_calls = [call(optim='sgd') for s in range(4)]
mock_sgd_step.assert_has_calls(expected_calls)

expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
expected_calls = [call() for s in range(2)]
mock_adam_step.assert_has_calls(expected_calls)


Expand Down