Skip to content

Commit

Permalink
Fixes #2936 (no fix needed) (#3892)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Oct 6, 2020
1 parent 893bed7 commit cb2a326
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion tests/trainer/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel


def test_optimizer_with_scheduling(tmpdir):
Expand Down Expand Up @@ -298,3 +299,50 @@ def test_init_optimizers_during_testing(tmpdir):
assert len(trainer.lr_schedulers) == 0
assert len(trainer.optimizers) == 0
assert len(trainer.optimizer_frequencies) == 0


def test_multiple_optimizers_callbacks(tmpdir):
"""
Tests that multiple optimizers can be used with callbacks
"""
class CB(Callback):

def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
pass

def on_train_epoch_start(self, trainer, pl_module):
pass

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer_1 = torch.nn.Linear(32, 2)
self.layer_2 = torch.nn.Linear(32, 2)

def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 0:
a = batch[0]
acc = self.layer_1(a)
else:
a = batch[0]
acc = self.layer_2(a)

acc = self.loss(acc, acc)
return acc

def configure_optimizers(self):
a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2)
b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2)
return a, b

model = TestModel()
model.training_epoch_end = None
trainer = Trainer(
callbacks=[CB()],
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)

0 comments on commit cb2a326

Please sign in to comment.