wrap tpu tests with process decorator #2582

Closed
wants to merge 11 commits
14 changes: 10 additions & 4 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -71,9 +71,9 @@ def teardown(self):
# load last weights
if last_path and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)
self.trainer.model.load_state_dict(ckpt)

self.trainer.model = model
self.trainer.model = self.trainer.model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()
@@ -82,13 +82,19 @@ def teardown(self):
def train(self):
model = self.trainer.model

if self.trainer.can_prepare_data():
self.trainer.model.prepare_data()
self._is_data_prepared = True

self.trainer.barrier('prepare_data')

# train
if self.trainer.tpu_id is not None:
self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
self.tpu_train_in_process(self.trainer.tpu_id, self.trainer.model, self.trainer, self.mp_queue)
else:
xmp.spawn(
self.tpu_train_in_process,
args=(model, self.trainer, self.mp_queue),
args=(self.trainer.model, self.trainer, self.mp_queue),
nprocs=self.trainer.tpu_cores,
start_method=self.start_method
)
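The block added above runs prepare_data() once in the parent process before the per-core workers are spawned, so each of the eight workers finds the dataset already on disk instead of all downloading it at once. A minimal standalone sketch of that pattern with plain torch_xla, not Lightning's internals (prepare_data and train_fn are placeholder names, not the library's API):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def prepare_data():
    # runs once in the parent process: download / write datasets to shared storage
    ...


def train_fn(index):
    # each spawned worker lands here with its own ordinal
    device = xm.xla_device()
    xm.rendezvous('data_ready')  # every worker waits here before touching the data
    ...  # build dataloaders from the already prepared data and train on `device`


if __name__ == '__main__':
    prepare_data()  # once, before any worker exists
    xmp.spawn(train_fn, args=(), nprocs=8, start_method='fork')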
14 changes: 9 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1414,12 +1414,16 @@ def __test_given_model(self, model, test_dataloaders):

return results

def barrier(self, name):
if self.use_ddp or self.use_ddp2:
pass
# torch_distrib.barrier()
def barrier(self, name, tpu=True, gpu=True):
if gpu and (self.use_ddp or self.use_ddp2):
# sometimes this may be called before the process group has been initialized;
# catch and ignore the exception in that case
try:
torch_distrib.barrier()
except Exception as e:
pass

if self.on_tpu and XLA_AVAILABLE:
if tpu and (self.on_tpu and XLA_AVAILABLE):
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')

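The try/except above exists because barrier() can now be reached before torch.distributed.init_process_group has run, in which case the collective raises. An alternative sketch of the same guard that checks the documented state flags instead of swallowing every exception (an illustration, not the code merged here):

import torch.distributed as torch_distrib


def barrier_if_initialized() -> None:
    # calling barrier() without an initialized process group raises,
    # so only issue the collective when a group actually exists
    if torch_distrib.is_available() and torch_distrib.is_initialized():
        torch_distrib.barrier()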
5 changes: 5 additions & 0 deletions tests/base/develop_utils.py
@@ -111,5 +111,10 @@ def inner_f(queue, **kwargs):

result = queue.get()
assert result == 1, 'expected 1, but returned %s' % result
assert result == 1

# cleaning
proc.close()
queue.close()

return wrapper
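The hunk above shows only the tail of pl_multi_process_test (collect the result, assert, clean up). For orientation, a sketch of what a test-wrapping process decorator of this shape generally looks like; the body below is an illustration assembled around the visible tail, not a verbatim copy of the helper:

import functools
import multiprocessing as mp
import traceback


def process_isolated_test(func):  # hypothetical name, standing in for pl_multi_process_test
    @functools.wraps(func)
    def wrapper(**kwargs):
        def inner_f(queue, **kwargs):
            # run the real test in the child and report success/failure via the queue
            try:
                func(**kwargs)
                queue.put(1)
            except Exception:
                traceback.print_exc()
                queue.put(-1)

        # with the default 'fork' start method on Linux the closure does not need to be picklable
        queue = mp.Queue()
        proc = mp.Process(target=inner_f, args=(queue,), kwargs=kwargs)
        proc.start()
        proc.join()

        result = queue.get()
        assert result == 1, 'expected 1, but returned %s' % result

        # cleaning
        proc.close()
        queue.close()

    return wrapper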
7 changes: 7 additions & 0 deletions tests/base/model_test_dataloaders.py
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from torch.utils.data import DataLoader
from tests.base import TrialMNIST
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


@@ -16,6 +18,11 @@ def test_dataloader(self):
def test_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))

def test_dataloader__long(self):
dataset = DataLoader(TrialMNIST(download=True, train=False, num_samples=15000, digits=(0, 1, 2, 5, 8)),
batch_size=32)
return dataset

def test_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

7 changes: 7 additions & 0 deletions tests/base/model_train_dataloaders.py
@@ -2,6 +2,8 @@

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader
from tests.base import TrialMNIST
from torch.utils.data import DataLoader


class TrainDataloaderVariations(ABC):
@@ -16,6 +18,11 @@ def train_dataloader(self):
def train_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=True))

def train_dataloader__long(self):
dataset = DataLoader(TrialMNIST(download=True, num_samples=15000,
digits=(0, 1, 2, 5, 8)), batch_size=32)
return dataset

def train_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=True))

7 changes: 7 additions & 0 deletions tests/base/model_valid_dataloaders.py
@@ -2,6 +2,8 @@

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader
from tests.base import TrialMNIST
from torch.utils.data import DataLoader


class ValDataloaderVariations(ABC):
@@ -22,6 +24,11 @@ def val_dataloader__multiple(self):
return [self.dataloader(train=False),
self.dataloader(train=False)]

def val_dataloader__long(self):
dataset = DataLoader(TrialMNIST(download=True, train=False,
num_samples=15000, digits=(0, 1, 2, 5, 8)), batch_size=32)
return dataset

def val_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))

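The three __long dataloaders added above back the 8-core TPU tests, where each spawned process only sees its share of the batches. A quick sanity check of the sizes, assuming an even split across the eight processes:

num_samples = 15000
batch_size = 32
tpu_cores = 8

total_batches = num_samples // batch_size        # 468 batches overall
batches_per_core = total_batches // tpu_cores    # about 58 batches for each process
print(total_batches, batches_per_core)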
60 changes: 30 additions & 30 deletions tests/models/test_tpu.py
@@ -4,11 +4,11 @@
from torch.utils.data import DataLoader

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as dutils
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test

try:
import torch_xla
@@ -34,16 +34,16 @@ def _serial_train_loader():


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
Member:

when was this function added? and never used till now?

Contributor Author:

i added it a few days ago when the spawn stuff came out...

def test_model_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=1,
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
)

model = EvalModelTemplate()
@@ -52,16 +52,16 @@ def test_model_tpu_cores_1(tmpdir):

@pytest.mark.parametrize('tpu_core', [1, 5])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=[tpu_core],
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
)

model = EvalModelTemplate()
@@ -70,28 +70,28 @@ def test_model_tpu_index(tmpdir, tpu_core):


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=8,
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=20,
limit_val_batches=20,
)

model = EvalModelTemplate()
# 8 cores needs a big dataset
model.train_dataloader = _serial_train_loader
model.val_dataloader = _serial_train_loader
model.train_dataloader = model.train_dataloader__long
model.val_dataloader = model.val_dataloader__long

tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_16bit_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
trainer_options = dict(
@@ -100,8 +100,8 @@ def test_model_16bit_tpu_cores_1(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=1,
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
)

model = EvalModelTemplate()
Expand All @@ -111,7 +111,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):

@pytest.mark.parametrize('tpu_core', [1, 5])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_16bit_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
trainer_options = dict(
@@ -122,8 +122,8 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
val_percent_check=0.2,
max_epochs=1,
tpu_cores=[tpu_core],
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
)

model = EvalModelTemplate()
Expand All @@ -133,7 +133,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_16bit_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
trainer_options = dict(
@@ -142,20 +142,20 @@ def test_model_16bit_tpu_cores_8(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=8,
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
)

model = EvalModelTemplate()
# 8 cores needs a big dataset
model.train_dataloader = _serial_train_loader
model.val_dataloader = _serial_train_loader
model.train_dataloader = model.train_dataloader__long
model.val_dataloader = model.val_dataloader__long

tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""
model = EvalModelTemplate()
@@ -172,16 +172,16 @@ def test_model_tpu_early_stop(tmpdir):


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_tpu_grad_norm(tmpdir):
"""Test if grad_norm works on TPU."""
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=1,
limit_train_batches=0.4,
limit_val_batches=0.4,
limit_train_batches=10,
limit_val_batches=10,
gradient_clip_val=0.1,
)

@@ -190,7 +190,7 @@ def test_tpu_grad_norm(tmpdir):


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_dataloaders_passed_to_fit(tmpdir):
"""Test if dataloaders passed to trainer works on TPU"""

@@ -199,7 +199,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
tpu_cores=8
tpu_cores=8,
)
result = trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
assert result, "TPU doesn't work with dataloaders passed to fit()."
@@ -244,7 +244,7 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
@dutils.pl_multi_process_test
def test_result_obj_on_tpu(tmpdir):
seed_everything(1234)
os.environ['PL_DEV_DEBUG'] = '1'
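The switch from limit_train_batches=0.4 / limit_val_batches=0.4 to small integers throughout these tests leans on the Trainer convention that a float limits to a fraction of the dataloader while an int limits to an absolute number of batches, which keeps the TPU runs short and predictable. A minimal illustration with arbitrary values:

from pytorch_lightning import Trainer

# float -> fraction of the available batches, int -> absolute batch count
trainer_fraction = Trainer(max_epochs=1, limit_train_batches=0.4)  # 40% of the batches
trainer_count = Trainer(max_epochs=1, limit_train_batches=10)      # exactly 10 batches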