diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index bbbc7a51dbb4a..159ab3cf4b26b 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -19,12 +19,8 @@ assignees: ''
### To Reproduce
-Steps to reproduce the behavior:
-
-1. Go to '...'
-2. Run '....'
-3. Scroll down to '....'
-4. See error
+Before reporting a bug, make sure the bug can be reproduced with a minimal example. You can simply subclass our [minimal code example](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py) and add your relevant changes to see whether the issue persists.
+If a test is failing, please add your test case to the issue (as a draft PR, or simply paste the code into the issue description here).
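+
+For example, a minimal sketch (`BoringModel` comes from the linked file; the override shown is just an illustration):
+
+```python
+from pl_examples.bug_report_model import BoringModel
+
+
+class TestModel(BoringModel):
+    def on_train_epoch_start(self):
+        # change or add only what is needed to reproduce your bug
+        print('override any method to prove your bug')
+```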
diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml
index 2076976bacd92..a71826905888b 100644
--- a/.github/workflows/ci_test-conda.yml
+++ b/.github/workflows/ci_test-conda.yml
@@ -31,6 +31,7 @@ jobs:
pip list
- name: Cache datasets
+ # todo: this probably does not work with docker images; consider caching the docker images instead
uses: actions/cache@v2
with:
path: Datasets # This path is specific to Ubuntu
diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml
index 960030c916f0d..2ff6ebe444094 100644
--- a/.github/workflows/ci_test-full.yml
+++ b/.github/workflows/ci_test-full.yml
@@ -103,7 +103,8 @@ jobs:
HOROVOD_BUILT=$(python -c "import horovod.torch; horovod.torch.nccl_built(); print('SUCCESS')" || true)
if [[ $HOROVOD_BUILT != "SUCCESS" ]]; then
pip uninstall -y horovod
- pip install --no-cache-dir $(grep "horovod" requirements/extra.txt)
+ echo $(grep "horovod" requirements/extra.txt) > requirements/horovod.txt
+ pip install --no-cache-dir -r requirements/horovod.txt
fi
horovodrun --check-build
shell: bash
diff --git a/.pyrightconfig.json b/.pyrightconfig.json
index cb14993e2cc0e..3f00d9a3e4454 100644
--- a/.pyrightconfig.json
+++ b/.pyrightconfig.json
@@ -35,6 +35,7 @@
"pytorch_lightning/trainer/connectors/checkpoint_connector.py",
"pytorch_lightning/trainer/connectors/data_connector.py",
"pytorch_lightning/trainer/connectors/logger_connector.py",
+ "pytorch_lightning/trainer/connectors/slurm_connector.py",
"pytorch_lightning/distributed/dist.py",
"pytorch_lightning/tuner",
"pytorch_lightning/plugins"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 15e8573f34baa..433c754777b95 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814))
+- Added `XLADeviceUtils` class to check XLA device type ([#3274](https://github.com/PyTorchLightning/pytorch-lightning/pull/3274))
+
### Changed
- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
@@ -57,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
+- Renamed Trainer arguments `row_log_interval` >> `log_every_n_steps` and `log_save_interval` >> `flush_logs_every_n_steps` ([#3748](https://github.com/PyTorchLightning/pytorch-lightning/pull/3748))
### Removed
@@ -101,10 +104,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
+- Fixed Tuner dump: add `current_epoch` to dumped_params ([#3261](https://github.com/PyTorchLightning/pytorch-lightning/pull/3261))
+
- Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))
- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))
+- Fixed learning rate scheduler for optimizers with internal state ([#3897](https://github.com/PyTorchLightning/pytorch-lightning/pull/3897))
+
## [0.9.0] - YYYY-MM-DD
### Added
diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile
index 2090320f3c10c..390d51ceb41e4 100644
--- a/dockers/base-cuda/Dockerfile
+++ b/dockers/base-cuda/Dockerfile
@@ -22,8 +22,11 @@
ARG CUDNN_VERSION=7
ARG CUDA_VERSION=10.1
-FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel
-# FROM nvidia/cuda:${CUDA_VERSION}-devel
+# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
+# FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu18.04
+FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu16.04
+# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu18.04
+# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu16.04
ARG PYTHON_VERSION=3.7
ARG PYTORCH_VERSION=1.6
diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst
deleted file mode 100644
index 4ccad84ef21de..0000000000000
--- a/docs/source/experiment_logging.rst
+++ /dev/null
@@ -1,242 +0,0 @@
-.. testsetup:: *
-
- from pytorch_lightning.trainer.trainer import Trainer
- from pytorch_lightning.core.lightning import LightningModule
-
-.. _experiment_logging:
-
-Experiment Logging
-==================
-
-Comet.ml
-^^^^^^^^
-
-`Comet.ml `_ is a third-party logger.
-To use :class:`~pytorch_lightning.loggers.CometLogger` as your logger do the following.
-First, install the package:
-
-.. code-block:: bash
-
- pip install comet-ml
-
-Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-
-.. testcode::
-
- import os
- from pytorch_lightning.loggers import CometLogger
- comet_logger = CometLogger(
- api_key=os.environ.get('COMET_API_KEY'),
- workspace=os.environ.get('COMET_WORKSPACE'), # Optional
- save_dir='.', # Optional
- project_name='default_project', # Optional
- rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional
- experiment_name='default' # Optional
- )
- trainer = Trainer(logger=comet_logger)
-
-The :class:`~pytorch_lightning.loggers.CometLogger` is available anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- self.logger.experiment.add_image('generated_images', some_img, 0)
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.CometLogger` docs.
-
-----------------
-
-MLflow
-^^^^^^
-
-`MLflow `_ is a third-party logger.
-To use :class:`~pytorch_lightning.loggers.MLFlowLogger` as your logger do the following.
-First, install the package:
-
-.. code-block:: bash
-
- pip install mlflow
-
-Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-
-.. testcode::
-
- from pytorch_lightning.loggers import MLFlowLogger
- mlf_logger = MLFlowLogger(
- experiment_name="default",
- tracking_uri="file:./ml-runs"
- )
- trainer = Trainer(logger=mlf_logger)
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.MLFlowLogger` docs.
-
-----------------
-
-Neptune.ai
-^^^^^^^^^^
-
-`Neptune.ai `_ is a third-party logger.
-To use :class:`~pytorch_lightning.loggers.NeptuneLogger` as your logger do the following.
-First, install the package:
-
-.. code-block:: bash
-
- pip install neptune-client
-
-Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-
-.. testcode::
-
- from pytorch_lightning.loggers import NeptuneLogger
-
- neptune_logger = NeptuneLogger(
- api_key='ANONYMOUS', # replace with your own
- project_name='shared/pytorch-lightning-integration',
- experiment_name='default', # Optional,
- params={'max_epochs': 10}, # Optional,
- tags=['pytorch-lightning', 'mlp'], # Optional,
- )
- trainer = Trainer(logger=neptune_logger)
-
-The :class:`~pytorch_lightning.loggers.NeptuneLogger` is available anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- self.logger.experiment.add_image('generated_images', some_img, 0)
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.NeptuneLogger` docs.
-
-----------------
-
-Tensorboard
-^^^^^^^^^^^
-
-To use `TensorBoard `_ as your logger do the following.
-
-.. testcode::
-
- from pytorch_lightning.loggers import TensorBoardLogger
- logger = TensorBoardLogger('tb_logs', name='my_model')
- trainer = Trainer(logger=logger)
-
-The :class:`~pytorch_lightning.loggers.TensorBoardLogger` is available anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- self.logger.experiment.add_image('generated_images', some_img, 0)
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.TensorBoardLogger` docs.
-
-----------------
-
-Test Tube
-^^^^^^^^^
-
-`Test Tube `_ is a
-`TensorBoard `_ logger but with nicer file structure.
-To use :class:`~pytorch_lightning.loggers.TestTubeLogger` as your logger do the following.
-First, install the package:
-
-.. code-block:: bash
-
- pip install test_tube
-
-Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-
-.. testcode::
-
- from pytorch_lightning.loggers import TestTubeLogger
- logger = TestTubeLogger('tb_logs', name='my_model')
- trainer = Trainer(logger=logger)
-
-The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- self.logger.experiment.add_image('generated_images', some_img, 0)
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.TestTubeLogger` docs.
-
-----------------
-
-Weights and Biases
-^^^^^^^^^^^^^^^^^^
-
-`Weights and Biases `_ is a third-party logger.
-To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following.
-First, install the package:
-
-.. code-block:: bash
-
- pip install wandb
-
-Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-
-.. code-block:: python
-
- from pytorch_lightning.loggers import WandbLogger
- wandb_logger = WandbLogger(offline=True)
- trainer = Trainer(logger=wandb_logger)
-
-The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- self.logger.experiment.log({
- "generated_images": [wandb.Image(some_img, caption="...")]
- })
-
-.. seealso::
- :class:`~pytorch_lightning.loggers.WandbLogger` docs.
-
-----------------
-
-Multiple Loggers
-^^^^^^^^^^^^^^^^
-
-Lightning supports the use of multiple loggers, just pass a list to the
-:class:`~pytorch_lightning.trainer.trainer.Trainer`.
-
-.. testcode::
-
- from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
- logger1 = TensorBoardLogger('tb_logs', name='my_model')
- logger2 = TestTubeLogger('tb_logs', name='my_model')
- trainer = Trainer(logger=[logger1, logger2])
-
-The loggers are available as a list anywhere except ``__init__`` in your
-:class:`~pytorch_lightning.core.lightning.LightningModule`.
-
-.. testcode::
-
- class MyModule(LightningModule):
- def any_lightning_module_function_or_hook(self):
- some_img = fake_image()
- # Option 1
- self.logger.experiment[0].add_image('generated_images', some_img, 0)
- # Option 2
- self.logger[0].experiment.add_image('generated_images', some_img, 0)
diff --git a/docs/source/experiment_reporting.rst b/docs/source/experiment_reporting.rst
deleted file mode 100644
index 4b6f0bb1efea4..0000000000000
--- a/docs/source/experiment_reporting.rst
+++ /dev/null
@@ -1,168 +0,0 @@
-.. testsetup:: *
-
- from pytorch_lightning.trainer.trainer import Trainer
-
-.. _experiment_reporting:
-
-Experiment Reporting
-=====================
-
-Lightning supports many different experiment loggers. These loggers allow you to monitor losses, images, text, etc...
-as training progresses. They usually provide a GUI to visualize and can sometimes even snapshot hyperparameters
-used in each experiment.
-
-----------
-
-Control logging frequency
-^^^^^^^^^^^^^^^^^^^^^^^^^
-
-It may slow training down to log every single batch. Trainer has an option to log every k batches instead.
-
-.. testcode::
-
- k = 10
- trainer = Trainer(row_log_interval=k)
-
-----------
-
-Control log writing frequency
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Writing to a logger can be expensive. In Lightning you can set the interval at which you
-want to save logs to the filesystem using this trainer flag.
-
-.. testcode::
-
- k = 100
- trainer = Trainer(log_save_interval=k)
-
-Unlike the `row_log_interval`, this argument does not apply to all loggers.
-The example shown here works with :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`,
-which is the default logger in Lightning.
-
-----------
-
-Log metrics
-^^^^^^^^^^^
-
-To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
-
-1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.
-
-.. testcode::
-
- def training_epoch_end(self, outputs):
- loss = some_loss()
- ...
-
- logs = {'train_loss': loss}
- results = {'log': logs}
- return results
-
- def validation_epoch_end(self, outputs):
- loss = some_loss()
- ...
-
- logs = {'val_loss': loss}
- results = {'log': logs}
- return results
-
- def test_epoch_end(self, outputs):
- loss = some_loss()
- ...
-
- logs = {'test_loss': loss}
- results = {'log': logs}
- return results
-
-2. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
-For instance, here we log images using tensorboard.
-
-.. testcode::
- :skipif: not TORCHVISION_AVAILABLE
-
- def training_step(self, batch, batch_idx):
- self.generated_imgs = self.decoder.generate()
-
- sample_imgs = self.generated_imgs[:6]
- grid = torchvision.utils.make_grid(sample_imgs)
- self.logger.experiment.add_image('generated_images', grid, 0)
-
- ...
- return results
-
-----------
-
-Modify progress bar
-^^^^^^^^^^^^^^^^^^^
-
-Each return dict from the
-:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
-:meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`,
-:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` and
-:meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`
-can also contain a key called `progress_bar`.
-
-Here we show the validation loss in the progress bar:
-
-.. testcode::
-
- def validation_epoch_end(self, outputs):
- loss = some_loss()
- ...
-
- logs = {'val_loss': loss}
- results = {'progress_bar': logs}
- return results
-
-The progress bar by default already includes the training loss and version number of the experiment
-if you are using a logger. These defaults can be customized by overriding the
-:meth:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
-
-
-----------
-
-Configure console logging
-^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Lightning logs useful information about the training process and user warnings to the console.
-You can retrieve the Lightning logger and change it to your liking. For example, increase the logging level
-to see fewer messages like so:
-
-.. code-block:: python
-
- import logging
- logging.getLogger("lightning").setLevel(logging.ERROR)
-
-Read more about custom Python logging `here `_.
-
-
-----------
-
-Snapshot hyperparameters
-^^^^^^^^^^^^^^^^^^^^^^^^
-
-When training a model, it's useful to know what hyperparams went into that model.
-When Lightning creates a checkpoint, it stores a key "hparams" with the hyperparams.
-
-.. code-block:: python
-
- lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
- hyperparams = lightning_checkpoint['hparams']
-
-Some loggers also allow logging the hyperparams used in the experiment. For instance,
-when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show
-in the `hparams tab `_.
-
-----------
-
-Snapshot code
-^^^^^^^^^^^^^
-
-Loggers also allow you to snapshot a copy of the code used in this experiment.
-For example, TestTubeLogger does this with a flag:
-
-.. testcode::
-
- from pytorch_lightning.loggers import TestTubeLogger
- logger = TestTubeLogger('.', create_git_tag=True)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 03cebc03871a8..d9c99c88b1147 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -37,7 +37,7 @@ PyTorch Lightning Documentation
callbacks
datamodules
- loggers
+ logging
metrics
.. toctree::
@@ -87,8 +87,7 @@ PyTorch Lightning Documentation
slurm
child_modules
debugging
- experiment_logging
- experiment_reporting
+ loggers
early_stopping
fast_training
hooks
diff --git a/docs/source/loggers.rst b/docs/source/loggers.rst
index 93b2e1cdacc13..0fae0f88cab5a 100644
--- a/docs/source/loggers.rst
+++ b/docs/source/loggers.rst
@@ -1,211 +1,253 @@
.. testsetup:: *
- from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
- from pytorch_lightning import loggers as pl_loggers
+ from pytorch_lightning.core.lightning import LightningModule
-.. role:: hidden
- :class: hidden-section
-
.. _loggers:
+*******
Loggers
-===========
-Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...).
-To use a logger, simply pass it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
-Lightning uses TensorBoard by default.
+*******
-.. testcode::
+Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...). TensorBoard is used by default,
+but you can pass any combination of the following loggers to the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+
+.. note::
+
+ All loggers log by default to `os.getcwd()`. To change the path without creating a logger, set
+ `Trainer(default_root_dir='/your/path/to/save/checkpoints')`.
+
+Read more about :ref:`logging` options.
+
+Comet.ml
+========
+
+`Comet.ml <https://www.comet.ml/>`_ is a third-party logger.
+To use :class:`~pytorch_lightning.loggers.CometLogger` as your logger, do the following.
+First, install the package:
- from pytorch_lightning import loggers as pl_loggers
+.. code-block:: bash
- tb_logger = pl_loggers.TensorBoardLogger('logs/')
- trainer = Trainer(logger=tb_logger)
+ pip install comet-ml
-Choose from any of the others such as MLflow, Comet, Neptune, WandB, ...
+Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
.. testcode::
- comet_logger = pl_loggers.CometLogger(save_dir='logs/')
+ import os
+ from pytorch_lightning.loggers import CometLogger
+ comet_logger = CometLogger(
+ api_key=os.environ.get('COMET_API_KEY'),
+ workspace=os.environ.get('COMET_WORKSPACE'), # Optional
+ save_dir='.', # Optional
+ project_name='default_project', # Optional
+ rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional
+ experiment_name='default' # Optional
+ )
trainer = Trainer(logger=comet_logger)
-To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ...
+The :class:`~pytorch_lightning.loggers.CometLogger` is available anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
.. testcode::
- tb_logger = pl_loggers.TensorBoardLogger('logs/')
- comet_logger = pl_loggers.CometLogger(save_dir='logs/')
- trainer = Trainer(logger=[tb_logger, comet_logger])
-
-.. note::
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ self.logger.experiment.add_image('generated_images', some_img, 0)
- All loggers log by default to `os.getcwd()`. To change the path without creating a logger set
- `Trainer(default_root_dir='/your/path/to/save/checkpoints')`
+.. seealso::
+ :class:`~pytorch_lightning.loggers.CometLogger` docs.
-----------
+----------------
-Logging from a LightningModule
-------------------------------
-Interact with loggers in two ways, automatically and/or manually.
+MLflow
+======
-Automatic logging
-^^^^^^^^^^^^^^^^^
-Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a LightningModule.
+`MLflow <https://mlflow.org/>`_ is a third-party logger.
+To use :class:`~pytorch_lightning.loggers.MLFlowLogger` as your logger, do the following.
+First, install the package:
-.. code-block:: python
+.. code-block:: bash
- def training_step(self, batch, batch_idx):
- self.log('my_metric', x)
+ pip install mlflow
-The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a few options:
+Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-- on_step (logs the metric at that step in training)
-- on_epoch (automatically accumulates and logs at the end of the epoch)
-- prog_bar (logs to the progress bar)
-- logger (logs to the logger like Tensorboard)
+.. testcode::
-Depending on where log is called from, Lightning auto-determines the correct mode for you. But of course
-you can override the default behavior by manually setting the flags
+ from pytorch_lightning.loggers import MLFlowLogger
+ mlf_logger = MLFlowLogger(
+ experiment_name="default",
+ tracking_uri="file:./ml-runs"
+ )
+ trainer = Trainer(logger=mlf_logger)
-.. note:: Setting on_epoch=True will accumulate your logged values over the full training epoch.
+.. seealso::
+ :class:`~pytorch_lightning.loggers.MLFlowLogger` docs.
-.. code-block:: python
+----------------
- def training_step(self, batch, batch_idx):
- self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+Neptune.ai
+==========
-Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs:
+`Neptune.ai <https://neptune.ai/>`_ is a third-party logger.
+To use :class:`~pytorch_lightning.loggers.NeptuneLogger` as your logger, do the following.
+First, install the package:
.. code-block:: bash
- tensorboard --logdir ./lightning_logs
+ pip install neptune-client
+Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-Manual logging
-^^^^^^^^^^^^^^
-For certain things like histograms, text, images, etc... you may need to use the logger object directly.
+.. testcode::
-.. code-block:: python
+ from pytorch_lightning.loggers import NeptuneLogger
+
+ neptune_logger = NeptuneLogger(
+ api_key='ANONYMOUS', # replace with your own
+ project_name='shared/pytorch-lightning-integration',
+ experiment_name='default', # Optional,
+ params={'max_epochs': 10}, # Optional,
+ tags=['pytorch-lightning', 'mlp'], # Optional,
+ )
+ trainer = Trainer(logger=neptune_logger)
- def training_step(...):
- ...
- # the logger you used (in this case tensorboard)
- tensorboard = self.logger.experiment
- tensorboard.add_histogram(...)
- tensorboard.add_figure(...)
+The :class:`~pytorch_lightning.loggers.NeptuneLogger` is available anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
-----------
+.. testcode::
-Logging from a Callback
------------------------
-To log from a callback, the :func:`~~pytorch_lightning.core.lightning.LightningModule.log`
-method of the LightningModule.
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ self.logger.experiment.add_image('generated_images', some_img, 0)
-.. code-block:: python
+.. seealso::
+ :class:`~pytorch_lightning.loggers.NeptuneLogger` docs.
- class MyCallback(Callback):
+----------------
- def on_train_epoch_end(self, trainer, pl_module):
- pl_module.log('something', x)
+Tensorboard
+===========
-or access the logger object directly
+To use `TensorBoard <https://www.tensorflow.org/tensorboard>`_ as your logger, do the following.
-.. code-block:: python
+.. testcode::
+
+ from pytorch_lightning.loggers import TensorBoardLogger
+ logger = TensorBoardLogger('tb_logs', name='my_model')
+ trainer = Trainer(logger=logger)
+
+The :class:`~pytorch_lightning.loggers.TensorBoardLogger` is available anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
+
+.. testcode::
+
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ self.logger.experiment.add_image('generated_images', some_img, 0)
+
+.. seealso::
+ :class:`~pytorch_lightning.loggers.TensorBoardLogger` docs.
- class MyCallback(Callback):
+----------------
- def on_train_epoch_end(self, trainer, pl_module):
- tensorboard = pl_module.logger.experiment
- tensorboard.add_histogram(...)
- tensorboard.add_figure(...)
+Test Tube
+=========
-----------
+`Test Tube <https://github.com/williamFalcon/test-tube>`_ is a
+`TensorBoard <https://www.tensorflow.org/tensorboard>`_ logger, but with a nicer file structure.
+To use :class:`~pytorch_lightning.loggers.TestTubeLogger` as your logger, do the following.
+First, install the package:
-Make a Custom Logger
---------------------
+.. code-block:: bash
+
+ pip install test_tube
-You can implement your own logger by writing a class that inherits from
-:class:`LightningLoggerBase`. Use the :func:`~pytorch_lightning.loggers.base.rank_zero_only`
-decorator to make sure that only the first process in DDP training logs data.
+Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
.. testcode::
- from pytorch_lightning.utilities import rank_zero_only
- from pytorch_lightning.loggers import LightningLoggerBase
+ from pytorch_lightning.loggers import TestTubeLogger
+ logger = TestTubeLogger('tb_logs', name='my_model')
+ trainer = Trainer(logger=logger)
+
+The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
- class MyLogger(LightningLoggerBase):
+.. testcode::
- @rank_zero_only
- def log_hyperparams(self, params):
- # params is an argparse.Namespace
- # your code to record hyperparameters goes here
- pass
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ self.logger.experiment.add_image('generated_images', some_img, 0)
- @rank_zero_only
- def log_metrics(self, metrics, step):
- # metrics is a dictionary of metric names and values
- # your code to record metrics goes here
- pass
+.. seealso::
+ :class:`~pytorch_lightning.loggers.TestTubeLogger` docs.
- def save(self):
- # Optional. Any code necessary to save logger data goes here
- # If you implement this, remember to call `super().save()`
- # at the start of the method (important for aggregation of metrics)
- super().save()
+----------------
- @rank_zero_only
- def finalize(self, status):
- # Optional. Any code that needs to be run after training
- # finishes goes here
- pass
+Weights and Biases
+==================
-If you write a logger that may be useful to others, please send
-a pull request to add it to Lightning!
+`Weights and Biases <https://www.wandb.com/>`_ is a third-party logger.
+To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger, do the following.
+First, install the package:
-----------
+.. code-block:: bash
-Supported Loggers
------------------
-The following are loggers we support
+ pip install wandb
-Comet
-^^^^^
+Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
-.. autoclass:: pytorch_lightning.loggers.comet.CometLogger
- :noindex:
+.. code-block:: python
-CSVLogger
-^^^^^^^^^
+ from pytorch_lightning.loggers import WandbLogger
+ wandb_logger = WandbLogger(offline=True)
+ trainer = Trainer(logger=wandb_logger)
-.. autoclass:: pytorch_lightning.loggers.csv_logs.CSVLogger
- :noindex:
+The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
-MLFlow
-^^^^^^
+.. testcode::
-.. autoclass:: pytorch_lightning.loggers.mlflow.MLFlowLogger
- :noindex:
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ self.logger.experiment.log({
+ "generated_images": [wandb.Image(some_img, caption="...")]
+ })
-Neptune
-^^^^^^^
+.. seealso::
+ :class:`~pytorch_lightning.loggers.WandbLogger` docs.
-.. autoclass:: pytorch_lightning.loggers.neptune.NeptuneLogger
- :noindex:
+----------------
-Tensorboard
-^^^^^^^^^^^^
+Multiple Loggers
+================
-.. autoclass:: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
- :noindex:
+Lightning supports the use of multiple loggers, just pass a list to the
+:class:`~pytorch_lightning.trainer.trainer.Trainer`.
-Test-tube
-^^^^^^^^^
+.. testcode::
-.. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
- :noindex:
+ from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
+ logger1 = TensorBoardLogger('tb_logs', name='my_model')
+ logger2 = TestTubeLogger('tb_logs', name='my_model')
+ trainer = Trainer(logger=[logger1, logger2])
+
+The loggers are available as a list anywhere except ``__init__`` in your
+:class:`~pytorch_lightning.core.lightning.LightningModule`.
-Weights and Biases
-^^^^^^^^^^^^^^^^^^
+.. testcode::
-.. autoclass:: pytorch_lightning.loggers.wandb.WandbLogger
- :noindex:
+ class MyModule(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ some_img = fake_image()
+ # Option 1
+ self.logger.experiment[0].add_image('generated_images', some_img, 0)
+ # Option 2
+ self.logger[0].experiment.add_image('generated_images', some_img, 0)
diff --git a/docs/source/logging.rst b/docs/source/logging.rst
new file mode 100644
index 0000000000000..d67b4dbfa45b0
--- /dev/null
+++ b/docs/source/logging.rst
@@ -0,0 +1,363 @@
+.. testsetup:: *
+
+ from pytorch_lightning.core.lightning import LightningModule
+ from pytorch_lightning.trainer.trainer import Trainer
+ from pytorch_lightning import loggers as pl_loggers
+
+.. role:: hidden
+ :class: hidden-section
+
+.. _logging:
+
+
+#######
+Logging
+#######
+
+Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...).
+To use a logger, simply pass it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+Lightning uses TensorBoard by default.
+
+.. testcode::
+
+ from pytorch_lightning import loggers as pl_loggers
+
+ tb_logger = pl_loggers.TensorBoardLogger('logs/')
+ trainer = Trainer(logger=tb_logger)
+
+Choose from any of the others such as MLflow, Comet, Neptune, WandB, ...
+
+.. testcode::
+
+ comet_logger = pl_loggers.CometLogger(save_dir='logs/')
+ trainer = Trainer(logger=comet_logger)
+
+To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ...
+
+.. testcode::
+
+ tb_logger = pl_loggers.TensorBoardLogger('logs/')
+ comet_logger = pl_loggers.CometLogger(save_dir='logs/')
+ trainer = Trainer(logger=[tb_logger, comet_logger])
+
+.. note::
+
+ By default, Lightning logs every 50 steps; see :ref:`logging_frequency` for the Trainer flags that control this.
+
+.. note::
+
+ All loggers log by default to `os.getcwd()`. To change the path without creating a logger, set
+ `Trainer(default_root_dir='/your/path/to/save/checkpoints')`.
+
+----------
+
+******************************
+Logging from a LightningModule
+******************************
+
+Lightning offers automatic logging for scalars, and manual logging for anything else.
+
+Automatic logging
+=================
+Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :class:`~pytorch_lightning.core.LightningModule`.
+
+.. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ self.log('my_metric', x)
+
+Depending on where it is called from, Lightning auto-determines the correct logging mode for you,
+but you can override the default behavior by manually setting the :func:`~pytorch_lightning.core.lightning.LightningModule.log` parameters.
+
+.. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+
+The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a few options:
+
+* on_step: Logs the metric at the current step. Defaults to True in :func:`~pytorch_lightning.core.lightning.LightningModule.training_step` and :func:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`.
+
+* on_epoch: Automatically accumulates and logs at the end of the epoch. Defaults to True anywhere in the validation or test loops, and in :func:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`.
+
+* prog_bar: Logs to the progress bar.
+
+* logger: Logs to the logger like Tensorboard, or any other custom logger passed to the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+
+
+.. note:: Setting on_epoch=True will accumulate your logged values over the full training epoch.
+
+
+Manual logging
+==============
+If you want to log anything that is not a scalar, like histograms, text, images, etc., you may need to use the logger object directly.
+
+.. code-block:: python
+
+ def training_step(...):
+ ...
+ # the logger you used (in this case tensorboard)
+ tensorboard = self.logger.experiment
+ tensorboard.add_image(...)
+ tensorboard.add_histogram(...)
+ tensorboard.add_figure(...)
+
+
+Access your logs
+================
+Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs:
+
+.. code-block:: bash
+
+ tensorboard --logdir ./lightning_logs
+
+----------
+
+***********************
+Logging from a Callback
+***********************
+To log from a callback, use the :func:`~pytorch_lightning.core.lightning.LightningModule.log`
+method of the :class:`~pytorch_lightning.core.LightningModule`.
+
+.. code-block:: python
+
+ class MyCallback(Callback):
+
+ def on_train_epoch_end(self, trainer, pl_module):
+ pl_module.log('something', x)
+
+or access the logger object directly for manual logging
+
+.. code-block:: python
+
+ class MyCallback(Callback):
+
+ def on_train_epoch_end(self, trainer, pl_module):
+ tensorboard = pl_module.logger.experiment
+ tensorboard.add_histogram(...)
+ tensorboard.add_figure(...)
+
+----------
+
+********************
+Make a custom logger
+********************
+
+You can implement your own logger by writing a class that inherits from
+:class:`LightningLoggerBase`. Use the :func:`~pytorch_lightning.loggers.base.rank_zero_only`
+decorator to make sure that only the first process in DDP training logs data.
+
+.. testcode::
+
+ from pytorch_lightning.utilities import rank_zero_only
+ from pytorch_lightning.loggers import LightningLoggerBase
+
+ class MyLogger(LightningLoggerBase):
+
+ @property
+ def name(self):
+ return 'MyLogger'
+
+ @property
+ def experiment(self):
+ # Return the experiment object associated with this logger.
+ pass
+
+ @property
+ def version(self):
+ # Return the experiment version, int or str.
+ return '0.1'
+
+ @rank_zero_only
+ def log_hyperparams(self, params):
+ # params is an argparse.Namespace
+ # your code to record hyperparameters goes here
+ pass
+
+ @rank_zero_only
+ def log_metrics(self, metrics, step):
+ # metrics is a dictionary of metric names and values
+ # your code to record metrics goes here
+ pass
+
+ def save(self):
+ # Optional. Any code necessary to save logger data goes here
+ # If you implement this, remember to call `super().save()`
+ # at the start of the method (important for aggregation of metrics)
+ super().save()
+
+ @rank_zero_only
+ def finalize(self, status):
+ # Optional. Any code that needs to be run after training
+ # finishes goes here
+ pass
+
+If you write a logger that may be useful to others, please send
+a pull request to add it to Lightning!
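+
+A minimal usage sketch (assuming the ``MyLogger`` defined above):
+
+.. code-block:: python
+
+    my_logger = MyLogger()
+    trainer = Trainer(logger=my_logger)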
+
+----------
+
+.. _logging_frequency:
+
+
+*************************
+Control logging frequency
+*************************
+
+Logging frequency
+=================
+
+Logging every single batch may slow training down. By default, Lightning logs every 50 training steps.
+To change this behaviour, set the `log_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.
+
+.. testcode::
+
+ k = 10
+ trainer = Trainer(log_every_n_steps=k)
+
+
+
+Log writing frequency
+=====================
+
+Writing to a logger can be expensive, so by default Lightning writes logs to disk or to the given logger every 100 training steps.
+To change this behaviour, set the interval at which you wish to flush logs to the filesystem using the `flush_logs_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.
+
+.. testcode::
+
+ k = 100
+ trainer = Trainer(flush_logs_every_n_steps=k)
+
+Unlike `log_every_n_steps`, this argument does not apply to all loggers.
+The example shown here works with :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`,
+which is the default logger in Lightning.
+
+----------
+
+************
+Progress Bar
+************
+You can add any metric to the progress bar using the :func:`~pytorch_lightning.core.lightning.LightningModule.log`
+method by setting `prog_bar=True`.
+
+
+.. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ self.log('my_loss', loss, prog_bar=True)
+
+
+Modifying the progress bar
+==========================
+
+The progress bar by default already includes the training loss and version number of the experiment
+if you are using a logger. These defaults can be customized by overriding the
+:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
+
+.. code-block:: python
+
+ def get_progress_bar_dict(self):
+ # don't show the version number
+ items = super().get_progress_bar_dict()
+ items.pop("v_num", None)
+ return items
+
+
+----------
+
+
+*************************
+Configure console logging
+*************************
+
+Lightning logs useful information about the training process and user warnings to the console.
+You can retrieve the Lightning logger and change it to your liking. For example, increase the logging level
+to see fewer messages like so:
+
+.. code-block:: python
+
+ import logging
+ logging.getLogger("lightning").setLevel(logging.ERROR)
+
+Read more about custom Python logging `here <https://docs.python.org/3/library/logging.html>`_.
+
+
+----------
+
+***********************
+Logging hyperparameters
+***********************
+
+When training a model, it's useful to know what hyperparams went into that model.
+When Lightning creates a checkpoint, it stores a key "hparams" with the hyperparams.
+
+.. code-block:: python
+
+ lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
+ hyperparams = lightning_checkpoint['hparams']
+
+Some loggers also allow logging the hyperparams used in the experiment. For instance,
+when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show
+in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html>`_.
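+
+For example, a minimal sketch of logging hyperparams manually (assuming your
+logger accepts a plain dict, as the TensorBoardLogger does):
+
+.. code-block:: python
+
+    trainer.logger.log_hyperparams({'max_epochs': 10, 'learning_rate': 1e-3})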
+
+----------
+
+*************
+Snapshot code
+*************
+
+Loggers also allow you to snapshot a copy of the code used in this experiment.
+For example, TestTubeLogger does this with a flag:
+
+.. testcode::
+
+ from pytorch_lightning.loggers import TestTubeLogger
+ logger = TestTubeLogger('.', create_git_tag=True)
+
+----------
+
+*****************
+Supported Loggers
+*****************
+
+The following are the loggers we support:
+
+Comet
+=====
+
+.. autoclass:: pytorch_lightning.loggers.comet.CometLogger
+ :noindex:
+
+CSVLogger
+=========
+
+.. autoclass:: pytorch_lightning.loggers.csv_logs.CSVLogger
+ :noindex:
+
+MLFlow
+======
+
+.. autoclass:: pytorch_lightning.loggers.mlflow.MLFlowLogger
+ :noindex:
+
+Neptune
+=======
+
+.. autoclass:: pytorch_lightning.loggers.neptune.NeptuneLogger
+ :noindex:
+
+Tensorboard
+============
+
+.. autoclass:: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
+ :noindex:
+
+Test-tube
+=========
+
+.. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
+ :noindex:
+
+Weights and Biases
+==================
+
+.. autoclass:: pytorch_lightning.loggers.wandb.WandbLogger
+ :noindex:
+
diff --git a/docs/source/lr_finder.rst b/docs/source/lr_finder.rst
index 83243082d9572..28988177d8251 100755
--- a/docs/source/lr_finder.rst
+++ b/docs/source/lr_finder.rst
@@ -29,20 +29,13 @@ initial lr.
Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-In the most basic use case, this feature can be enabled during trainer construction
-with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the LR finder
-will automatically run before any training is done. The ``lr`` that is found
-and used will be written to the console and logged together with all other
-hyperparameters of the model.
-
-.. testcode::
-
- # default: no automatic learning rate finder
- trainer = Trainer(auto_lr_find=False)
-
-This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.
+To enable the learning rate finder, your :class:`~pytorch_lightning.core.LightningModule` needs to have a ``learning_rate`` or ``lr`` property.
+Then, set ``Trainer(auto_lr_find=True)`` during trainer construction,
+and call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate``
+will be written to the console and automatically set on your :class:`~pytorch_lightning.core.LightningModule`,
+which can be accessed via ``self.learning_rate`` or ``self.lr``.
-.. testcode::
+.. code-block:: python
class LitModel(LightningModule):
@@ -51,31 +44,30 @@ This flag sets your learning rate which can be accessed via ``self.lr`` or ``sel
def configure_optimizers(self):
return Adam(self.parameters(), lr=(self.lr or self.learning_rate))
+
+ model = LitModel()
# finds learning rate automatically
- # sets hparams.lr or hparams.learning_rate to that learning rate
+ # sets self.lr or self.learning_rate to that learning rate
trainer = Trainer(auto_lr_find=True)
-To use an arbitrary value set it as auto_lr_find
+ trainer.tune(model)
+
+If your model stores the learning rate under an attribute name other than ``self.lr`` or ``self.learning_rate``, pass that name as ``auto_lr_find``:
-.. testcode::
+.. code-block:: python
+
+ model = LitModel()
# to set to your own hparams.my_value
trainer = Trainer(auto_lr_find='my_value')
-Under the hood, when you call fit it runs the learning rate finder before actually calling fit.
+ trainer.tune(model)
-.. code-block:: python
-
- # when you call .fit() this happens
- # 1. find learning rate
- # 2. actually run fit
- trainer.fit(model)
-If you want to inspect the results of the learning rate finder before doing any
-actual training or just play around with the parameters of the algorithm, this
-can be done by invoking the ``lr_find`` method of the trainer. A typical example
-of this would look like
+If you want to inspect the results of the learning rate finder or just play around
+with the parameters of the algorithm, this can be done by invoking the ``lr_find``
+method of the trainer. A typical example of this would look like
.. code-block:: python
diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst
index b0d9d2654c354..0de18a1b7f16c 100644
--- a/docs/source/training_tricks.rst
+++ b/docs/source/training_tricks.rst
@@ -54,34 +54,38 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
# DEFAULT (ie: don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)
- # Autoscale batch size
+ # Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')
# find the batch size
trainer.tune(model)
Currently, this feature supports two modes `'power'` scaling and `'binsearch'`
-scaling. In `'power'` scaling, starting from a batch size of 1 keeps doubling
-the batch size until an out-of-memory (OOM) error is encountered. Setting the
-argument to `'binsearch'` continues to finetune the batch size by performing
-a binary search.
+scaling. In `'power'` scaling, the batch size starts at 1 and keeps doubling
+until an out-of-memory (OOM) error is encountered. Setting the
+argument to `'binsearch'` will initially also keep doubling the batch size until
+it encounters an OOM, after which it runs a binary search to finetune the
+batch size. Note that the batch size scaler cannot search for batch
+sizes larger than the size of the training dataset.
-.. note::
- This feature expects that a `batch_size` field in the `hparams` of your model, i.e.,
- `model.hparams.batch_size` should exist and will be overridden by the results of this
- algorithm. Additionally, your `train_dataloader()` method should depend on this field
+.. note::
+
+ This feature expects a `batch_size` field to exist either as a model attribute,
+ i.e. `model.batch_size`, or as a field in your `hparams`, i.e. `model.hparams.batch_size`.
+ The field will be overridden by the results of this algorithm.
+ Additionally, your `train_dataloader()` method should depend on this field
for this feature to work i.e.
.. code-block:: python
-
+
def train_dataloader(self):
- return DataLoader(train_dataset, batch_size=self.batch_size)
+ return DataLoader(train_dataset, batch_size=self.batch_size)  # or self.hparams.batch_size
.. warning::
-
+
Due to these constraints, this feature does *NOT* work when passing dataloaders directly
- to `.fit()`.
+ to `.fit()`.
The scaling algorithm has a number of parameters that the user can control by
invoking the trainer method `.scale_batch_size` themself (see description below).
@@ -93,29 +97,29 @@ invoking the trainer method `.scale_batch_size` themself (see description below)
tuner = Tuner(trainer)
# Invoke method
- new_batch_size = tuner.scale_batch_size(model, ...)
+ new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
# Override old batch size
model.hparams.batch_size = new_batch_size
-
+
# Fit as normal
trainer.fit(model)
The algorithm in short works by:
1. Dumping the current state of the model and trainer
2. Iteratively until convergence or maximum number of tries `max_trials` (default 25) has been reached:
- - Call `fit()` method of trainer. This evaluates `steps_per_trial` (default 3) number of
- training steps. Each training step can trigger an OOM error if the tensors
- (training batch, weights, gradients ect.) allocated during the steps have a
+ - Call `fit()` method of trainer. This evaluates `steps_per_trial` (default 3) number of
+ training steps. Each training step can trigger an OOM error if the tensors
+ (training batch, weights, gradients etc.) allocated during the steps have a
too large memory footprint.
- If an OOM error is encountered, decrease batch size else increase it.
How much the batch size is increased/decreased is determined by the chosen
strategy.
- 3. The found batch size is saved to `model.hparams.batch_size`
+ 3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size`
4. Restore the initial state of model and trainer
-.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin
- :members: scale_batch_size
+.. autoclass:: pytorch_lightning.tuner.tuning.Tuner
:noindex:
+ :members: scale_batch_size
.. warning:: Batch size finder is not supported for DDP yet, it is coming soon.
diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py
new file mode 100644
index 0000000000000..27ecf774623af
--- /dev/null
+++ b/pl_examples/bug_report_model.py
@@ -0,0 +1,133 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# --------------------------------------------
+# --------------------------------------------
+# --------------------------------------------
+# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
+# --------------------------------------------
+# --------------------------------------------
+# --------------------------------------------
+import os
+import torch
+from torch.utils.data import Dataset
+from pytorch_lightning import Trainer, LightningModule
+
+
+class RandomDataset(Dataset):
+ def __init__(self, size, length):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ return self.data[index]
+
+ def __len__(self):
+ return self.len
+
+
+class BoringModel(LightningModule):
+
+ def __init__(self):
+ """
+ Testing PL Module
+
+ Use as follows:
+ - subclass
+ - modify the behavior for what you want
+
+ class TestModel(BoringModel):
+ def training_step(...):
+ # do your own thing
+
+ or:
+
+ model = BoringModel()
+ model.training_epoch_end = None
+
+ """
+ super().__init__()
+ self.layer = torch.nn.Linear(32, 2)
+
+ def forward(self, x):
+ return self.layer(x)
+
+ def loss(self, batch, prediction):
+ # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+ return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+ def step(self, x):
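+ # convenience helper: forward pass + loss in one call (not used by the default steps below)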
+ x = self.layer(x)
+ out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+ return out
+
+ def training_step(self, batch, batch_idx):
+ output = self.layer(batch)
+ loss = self.loss(batch, output)
+ return {"loss": loss}
+
+ def training_step_end(self, training_step_outputs):
+ return training_step_outputs
+
+ def training_epoch_end(self, outputs) -> None:
+ torch.stack([x["loss"] for x in outputs]).mean()
+
+ def validation_step(self, batch, batch_idx):
+ output = self.layer(batch)
+ loss = self.loss(batch, output)
+ return {"x": loss}
+
+ def validation_epoch_end(self, outputs) -> None:
+ torch.stack([x['x'] for x in outputs]).mean()
+
+ def test_step(self, batch, batch_idx):
+ output = self.layer(batch)
+ loss = self.loss(batch, output)
+ return {"y": loss}
+
+ def test_epoch_end(self, outputs) -> None:
+ torch.stack([x["y"] for x in outputs]).mean()
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+ return [optimizer], [lr_scheduler]
+
+
+def run_test():
+ class TestModel(BoringModel):
+
+ def on_train_epoch_start(self) -> None:
+ print('override any method to prove your bug')
+
+ # fake data
+ train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+ val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+ test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+ # model
+ model = TestModel()
+ trainer = Trainer(
+ default_root_dir=os.getcwd(),
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model, train_data, val_data)
+ trainer.test(test_dataloaders=test_data)
+
+
+if __name__ == '__main__':
+ run_test()
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 79aad9aef3589..cb05a3838aaea 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -116,7 +116,9 @@ def _call_children_scripts(self):
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'
env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])
- env_copy['PL_GLOBAL_SEED'] = os.environ.get('PL_GLOBAL_SEED')
+ # remove env var if global seed not set
+ if os.environ.get('PL_GLOBAL_SEED') is None:
+ del env_copy['PL_GLOBAL_SEED']
# start process
# if hydra is available and initialized, make sure to set the cwd correctly
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index 7dc437978fd3b..a54cbb7f1ac1e 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -21,20 +21,19 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.core import LightningModule
-from pytorch_lightning.distributed import LightningDistributed
-from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-try:
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
+if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
- XLA_AVAILABLE = False
-else:
- XLA_AVAILABLE = True
class TPUBackend(Accelerator):
@@ -47,7 +46,8 @@ def __init__(self, trainer, cluster_environment=None):
def setup(self, model):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
- if not XLA_AVAILABLE:
+ # TODO: Move this check to Trainer __init__ or device parser
+ if not TPU_AVAILABLE:
raise MisconfigurationException('PyTorch XLA not installed.')
# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@@ -171,7 +171,7 @@ def to_device(self, batch):
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
- if not XLA_AVAILABLE:
+ if not TPU_AVAILABLE:
raise MisconfigurationException(
'Requested to transfer batch to TPU but XLA is not available.'
' Are you sure this machine has TPUs?'
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 3177c9300efb3..58fc9ec7816d6 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -19,23 +19,22 @@
Monitor a validation metric and stop training when it stops improving.
"""
+import os
+
import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
-import os
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
+
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
torch_inf = torch.tensor(np.Inf)
-try:
- import torch_xla
- import torch_xla.core.xla_model as xm
-except ImportError:
- XLA_AVAILABLE = False
-else:
- XLA_AVAILABLE = True
+
class EarlyStopping(Callback):
@@ -186,7 +185,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
- if trainer.use_tpu and XLA_AVAILABLE:
+ if trainer.use_tpu and TPU_AVAILABLE:
current = current.cpu()
if self.monitor_op(current - self.min_delta, self.best_score):
@@ -206,6 +205,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
+ # todo: remove this old warning
rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 6a540bba7022e..e1f82968c7990 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -30,6 +30,8 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.core.step_result import TrainResult, EvalResult
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
@@ -43,12 +45,10 @@
from torch.optim.optimizer import Optimizer
-try:
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
+if TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
-except ImportError:
- XLA_AVAILABLE = False
-else:
- XLA_AVAILABLE = True
class LightningModule(
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 14ff996aa3949..756fffc60a0b1 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -21,6 +21,7 @@
import os
from pytorch_lightning.metrics.converters import sync_ddp_if_available
+from typing import Iterable
class Result(Dict):
@@ -217,7 +218,12 @@ def __set_meta(
_internal = self['meta']['_internal']
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
- def track_batch_size(self, batch_size):
+ def track_batch_size(self, batch):
+ try:
+ batch_size = self.unpack_batch_size(batch)
+ except RecursionError:
+ batch_size = 1
+
meta = self['meta']
meta['_internal']['batch_sizes'].append(batch_size)
@@ -321,6 +327,25 @@ def __copy__(self):
newone[k] = copy(v)
return newone
+ def unpack_batch_size(self, sample):
+ """
+ Recursively unpack a sample to find a torch.Tensor.
+ Returns the tensor's first dimension when found, or 1 when it hits an empty or non-iterable value.
+ """
+ if isinstance(sample, torch.Tensor):
+ size = sample.size(0)
+ elif isinstance(sample, str):
+ return len(sample)
+ elif isinstance(sample, dict):
+ sample = next(iter(sample.values()), 1)
+ size = self.unpack_batch_size(sample)
+ elif isinstance(sample, Iterable):
+ sample = next(iter(sample), 1)
+ size = self.unpack_batch_size(sample)
+ else:
+ size = 1
+ return size
+
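For a concrete sense of the recursion above, a minimal sketch (illustrative only, not part of the patch) of how a nested batch resolves to a batch size:

    import torch

    batch = {'x': torch.randn(16, 3), 'id': ['a'] * 16}
    # unpack_batch_size takes the first dict value ('x'), finds a tensor,
    # and returns its first dimension, so track_batch_size records 16 here.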
@classmethod
def gather(cls, outputs):
meta = outputs[0].get('meta')
@@ -387,7 +412,10 @@ def reduce_on_epoch_end(cls, outputs):
if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
- reduced_val = weighted_mean(result[k], batch_sizes)
+ try:
+ reduced_val = weighted_mean(result[k], batch_sizes)
+ except Exception:
+ reduced_val = torch.mean(result[k])
else:
reduced_val = fx(result[k])
@@ -420,7 +448,12 @@ def reduce_across_time(cls, time_outputs):
tbptt_reduce_fx = torch.mean
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
- result[k] = tbptt_reduce_fx(value.float())
+
+ if isinstance(value, dict):
+ # TODO: recursive reduce:
+ _recursive_fx_apply(value, tbptt_reduce_fx)
+ else:
+ result[k] = tbptt_reduce_fx(value.float())
result['meta'] = meta
return result
@@ -466,12 +499,14 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] =
for k, v in out.items():
if isinstance(v, dict):
- v = recursive_gather([v], result)
-
- if k not in result:
- result[k] = []
+ in_d = result.get(k, {})
+ v = recursive_gather([v], in_d)
+ result[k] = v
+ else:
+ if k not in result:
+ result[k] = []
- result[k].append(v)
+ result[k].append(v)
return result
@@ -484,6 +519,18 @@ def recursive_stack(result: MutableMapping):
result[k] = collate_tensors(v)
+def _recursive_fx_apply(input: dict, fx):
+ for k, v in input.items():
+ if isinstance(v, list):
+ v = torch.tensor(v)
+
+ if isinstance(v, torch.Tensor):
+ v = fx(v.float())
+ input[k] = v
+ else:
+ _recursive_fx_apply(v, fx)
+
+
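A short sketch of the in-place reduction performed by `_recursive_fx_apply` (illustrative values): lists are promoted to tensors, tensors are reduced with `fx`, and nested dicts recurse.

    import torch

    d = {'log': {'loss': [1.0, 2.0, 3.0]}}
    _recursive_fx_apply(d, torch.mean)
    # d is now {'log': {'loss': tensor(2.)}}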
def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]:
if not items or not isinstance(items, (list, tuple)) or any(not isinstance(item, Tensor) for item in items):
# items is not a sequence, empty, or contains non-tensors
@@ -894,9 +941,41 @@ def write_dict(self, predictions_dict, filename='predictions.pt'):
def weighted_mean(result, weights):
- if not isinstance(result, torch.Tensor):
- result = torch.tensor(result)
- weights = weights.to(result.device)[:result.size(0)]
- numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
- result = numerator / weights.sum().float()
+
+ if isinstance(result, dict):
+ _process_dataloader_aggregated_steps(result, weights)
+ else:
+ if isinstance(result, list):
+ result = torch.tensor(result)
+
+ weights = weights.to(result.device)[:result.size(0)]
+ numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
+ result = numerator / weights.sum().float()
return result
+
+
+def _process_dataloader_aggregated_steps(result, weights):
+ internal_keys = {'meta'}
+
+ moved = False
+
+ for k, v in result.items():
+ if k in internal_keys:
+ continue
+
+ # make sure v is a tensor
+ if not isinstance(v, torch.Tensor):
+ v = torch.tensor(v)
+
+ # move to memory only once
+ if not moved:
+ weights = weights.to(v.device)
+ moved = True
+
+ # move weights to same device as value to reduce
+ weights_t = weights[:v.size(0)]
+
+ # weighted mean
+ numerator = torch.dot(v.float(), weights_t.transpose(-1, 0).float())
+ v = numerator / weights.sum().float()
+ result[k] = v
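The arithmetic of the per-key reduction above, on illustrative numbers (note the denominator uses the full `weights.sum()`, as in the code):

    import torch

    v = torch.tensor([1.0, 3.0])          # per-batch metric values
    weights = torch.tensor([2.0, 6.0])    # batch sizes
    # torch.dot(v, weights) / weights.sum() == (1*2 + 3*6) / 8 == 2.5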
diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py
index 37a91bd98cdac..56a7b77dfac2b 100644
--- a/pytorch_lightning/loggers/comet.py
+++ b/pytorch_lightning/loggers/comet.py
@@ -22,28 +22,25 @@
from typing import Any, Dict, Optional, Union
try:
- from comet_ml import BaseExperiment as CometBaseExperiment
+ import comet_ml
+
+except ModuleNotFoundError: # pragma: no-cover
+ comet_ml = None
+ CometExperiment = None
+ CometExistingExperiment = None
+ CometOfflineExperiment = None
+ API = None
+ generate_guid = None
+else:
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
- from comet_ml import generate_guid
try:
from comet_ml.api import API
except ImportError: # pragma: no-cover
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover
- from comet_ml.config import get_api_key, get_config
-except ImportError: # pragma: no-cover
- CometExperiment = None
- CometExistingExperiment = None
- CometOfflineExperiment = None
- CometBaseExperiment = None
- API = None
- generate_guid = None
- _COMET_AVAILABLE = False
-else:
- _COMET_AVAILABLE = True
import torch
from torch import is_tensor
@@ -117,17 +114,17 @@ class CometLogger(LightningLoggerBase):
"""
def __init__(
- self,
- api_key: Optional[str] = None,
- save_dir: Optional[str] = None,
- project_name: Optional[str] = None,
- rest_api_key: Optional[str] = None,
- experiment_name: Optional[str] = None,
- experiment_key: Optional[str] = None,
- offline: bool = False,
- **kwargs
+ self,
+ api_key: Optional[str] = None,
+ save_dir: Optional[str] = None,
+ project_name: Optional[str] = None,
+ rest_api_key: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ experiment_key: Optional[str] = None,
+ offline: bool = False,
+ **kwargs
):
- if not _COMET_AVAILABLE:
+ if comet_ml is None:
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
" install it with `pip install comet-ml`."
@@ -136,7 +133,7 @@ def __init__(
self._experiment = None
# Determine online or offline mode based on which arguments were passed to CometLogger
- api_key = api_key or get_api_key(None, get_config())
+ api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config())
if api_key is not None and save_dir is not None:
self.mode = "offline" if offline else "online"
@@ -173,7 +170,7 @@ def __init__(
@property
@rank_zero_experiment
- def experiment(self) -> CometBaseExperiment:
+ def experiment(self):
r"""
Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
@@ -236,7 +233,6 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
metrics_without_epoch = metrics.copy()
epoch = metrics_without_epoch.pop('epoch', None)
-
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
def reset_experiment(self):
@@ -284,7 +280,7 @@ def version(self) -> str:
return self._future_experiment_key
# Pre-generate an experiment key
- self._future_experiment_key = generate_guid()
+ self._future_experiment_key = comet_ml.generate_guid()
return self._future_experiment_key
diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py
index 5433ef907990d..cd4f5c8073503 100644
--- a/pytorch_lightning/loggers/mlflow.py
+++ b/pytorch_lightning/loggers/mlflow.py
@@ -23,11 +23,9 @@
try:
import mlflow
from mlflow.tracking import MlflowClient
- _MLFLOW_AVAILABLE = True
except ModuleNotFoundError: # pragma: no-cover
mlflow = None
MlflowClient = None
- _MLFLOW_AVAILABLE = False
from pytorch_lightning import _logger as log
@@ -83,7 +81,7 @@ def __init__(
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns'
):
- if not _MLFLOW_AVAILABLE:
+ if mlflow is None:
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
' install it with `pip install mlflow`.')
super().__init__()
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 177e314ebf7b7..2e868452637bd 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -32,6 +32,7 @@
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from omegaconf import Container, OmegaConf
@@ -179,7 +180,15 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
for k, v in metrics.items():
if isinstance(v, torch.Tensor):
v = v.item()
- self.experiment.add_scalar(k, v, step)
+
+ if isinstance(v, dict):
+ self.experiment.add_scalars(k, v, step)
+ else:
+ try:
+ self.experiment.add_scalar(k, v, step)
+ except Exception:
+ m = f'You tried to log {v} which is currently not supported. Try a dict or a scalar/tensor.'
+ raise MisconfigurationException(m)
@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
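With the branch above, a usage sketch (assuming a `TensorBoardLogger` instance named `logger`; the nested-dict form is an assumption about caller intent, not a documented API guarantee):

    logger.log_metrics({'acc': 0.91}, step=10)                          # -> add_scalar
    logger.log_metrics({'loss': {'train': 0.3, 'val': 0.4}}, step=10)   # -> add_scalars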
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index fd25ca129daf8..3c40896cb8ae5 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -242,24 +242,29 @@ def forward(self, x):
auto_lr_find
^^^^^^^^^^^^
Runs a learning rate finder algorithm (see this `paper `_)
-before any training, to find optimal initial learning rate.
+when calling trainer.tune(), to find the optimal initial learning rate.
.. code-block:: python
# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)
- # call tune to find the lr
- trainer.tune(model)
-
Example::
# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)
+ # call tune to find the lr
+ trainer.tune(model)
+
+Example::
+
# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')
+ # call tune to find the lr
+ trainer.tune(model)
+
.. note::
See the :ref:`learning rate finder guide `.
@@ -604,18 +609,18 @@ def world_size(self):
.. note:: Might slow performance because it uses the output of nvidia-smi.
-log_save_interval
-^^^^^^^^^^^^^^^^^
+flush_logs_every_n_steps
+^^^^^^^^^^^^^^^^^^^^^^^^
Writes logs to disk this often.
.. testcode::
# default used by the Trainer
- trainer = Trainer(log_save_interval=100)
+ trainer = Trainer(flush_logs_every_n_steps=100)
See Also:
- - :ref:`Experiment Reporting `
+ - :ref:`Logging `
logger
^^^^^^
@@ -939,18 +944,18 @@ def world_size(self):
# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
-row_log_interval
-^^^^^^^^^^^^^^^^
+log_every_n_steps
+^^^^^^^^^^^^^^^^^
How often to add logging rows (does not write to disk)
.. testcode::
# default used by the Trainer
- trainer = Trainer(row_log_interval=50)
+ trainer = Trainer(log_every_n_steps=50)
See Also:
- - :ref:`Experiment Reporting `
+ - :ref:`Logging `
sync_batchnorm
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index caf20a888aaa2..044039f70839d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -149,7 +149,7 @@ def restore_training_state(self, checkpoint):
self.trainer.global_step = checkpoint['global_step']
self.trainer.current_epoch = checkpoint['epoch']
# crash if max_epochs is lower than the current epoch from the checkpoint
if self.trainer.current_epoch > self.trainer.max_epochs:
m = f"""
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index e7fcc0c005fe2..6966b95cd8415 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -34,11 +34,13 @@ def __init__(self, trainer):
self.progress_bar_metrics = {}
self.eval_loop_results = []
- def on_trainer_init(self, logger, log_save_interval, row_log_interval):
+ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
# logging
self.configure_logger(logger)
- self.trainer.log_save_interval = log_save_interval
- self.trainer.row_log_interval = row_log_interval
+ # todo: IDE is complaining; these should be initialized in the Trainer __init__, at least as placeholders,
+ # and assigned the desired value here
+ self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
+ self.trainer.log_every_n_steps = log_every_n_steps
def configure_logger(self, logger):
if logger is True:
@@ -510,7 +512,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):
def log_train_step_metrics(self, batch_output):
# when metrics should be logged
should_log_metrics = (
- (self.trainer.global_step + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
+ (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop
)
if should_log_metrics or self.trainer.fast_dev_run:
# logs user requested information to logger
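The modulo schedule above means that, with the default `log_every_n_steps=50`, logging fires on `global_step` values 49, 99, 149, ... (a quick sanity check, not library code):

    log_every_n_steps = 50
    assert all((step + 1) % log_every_n_steps == 0 for step in (49, 99, 149))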
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index a4ca9b3025cad..5f2cb3a8949c0 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -27,22 +27,18 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from copy import deepcopy
-
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
try:
from apex import amp
except ImportError:
amp = None
-try:
+if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
- import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
- XLA_AVAILABLE = False
-else:
- XLA_AVAILABLE = True
try:
import horovod.torch as hvd
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
index 149248e245268..523572098e92b 100644
--- a/pytorch_lightning/trainer/deprecated_api.py
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -13,3 +13,42 @@
# limitations under the License.
"""Mirroring deprecated API"""
+from abc import ABC
+
+from pytorch_lightning.utilities import rank_zero_warn
+
+
+class TrainerDeprecatedAPITillVer0_11(ABC):
+ flush_logs_every_n_steps: int
+ log_every_n_steps: int
+
+ def __init__(self):
+ super().__init__() # mixin calls super too
+
+ @property
+ def log_save_interval(self) -> int:
+ """Back compatibility, will be removed in v0.11.0"""
+ rank_zero_warn("Attribute `log_save_interval` is now set by `flush_logs_every_n_steps` since v0.10.0"
+ " and this method will be removed in v0.11.0", DeprecationWarning)
+ return self.flush_logs_every_n_steps
+
+ @log_save_interval.setter
+ def log_save_interval(self, val: int):
+ """Back compatibility, will be removed in v0.11.0"""
+ rank_zero_warn("Attribute `log_save_interval` is now set by `flush_logs_every_n_steps` since v0.10.0"
+ " and this method will be removed in v0.11.0", DeprecationWarning)
+ self.flush_logs_every_n_steps = val
+
+ @property
+ def row_log_interval(self) -> int:
+ """Back compatibility, will be removed in v0.10.0"""
+ rank_zero_warn("Attribute `row_log_interval` is now set by `log_every_n_steps` since v0.10.0"
+ " and this method will be removed in v0.11.0", DeprecationWarning)
+ return self.log_every_n_steps
+
+ @row_log_interval.setter
+ def row_log_interval(self, val: int):
+ """Back compatibility, will be removed in v0.10.0"""
+ rank_zero_warn("Attribute `row_log_interval` is now set by `log_every_n_steps` since v0.10.0"
+ " and this method will be removed in v0.11.0", DeprecationWarning)
+ self.log_every_n_steps = val
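A back-compat sketch of the mixin above (hypothetical session; each old-name access emits a `DeprecationWarning`):

    trainer = Trainer(log_every_n_steps=25)
    assert trainer.row_log_interval == 25      # property forwards to log_every_n_steps
    trainer.log_save_interval = 200            # setter forwards to flush_logs_every_n_steps
    assert trainer.flush_logs_every_n_steps == 200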
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 5e04173e6df5c..8d65561ad06aa 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -165,7 +165,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
# track batch size for weighted average
is_result_obj = isinstance(output, Result)
if is_result_obj:
- output.track_batch_size(len(batch))
+ output.track_batch_size(batch)
# allow only EvalResult when using structured results (from val_step)
if is_result_obj and not isinstance(output, EvalResult):
@@ -320,7 +320,7 @@ def log_evaluation_step_metrics(self, batch, batch_idx):
if len(results) == 1:
return None
- results.track_batch_size(len(batch))
+ results.track_batch_size(batch)
self.__log_result_step_metrics(results, batch_idx)
return results
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 219df9f67301d..829101c2707ab 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -29,6 +29,7 @@
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
+from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_11
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@@ -78,6 +79,7 @@ class Trainer(
TrainerLoggingMixin,
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
+ TrainerDeprecatedAPITillVer0_11,
):
def __init__(
self,
@@ -108,8 +110,8 @@ def __init__(
limit_val_batches: Union[int, float] = 1.0,
limit_test_batches: Union[int, float] = 1.0,
val_check_interval: Union[int, float] = 1.0,
- log_save_interval: int = 100,
- row_log_interval: int = 50,
+ flush_logs_every_n_steps: int = 100,
+ log_every_n_steps: int = 50,
distributed_backend: Optional[str] = None,
sync_batchnorm: bool = False,
precision: int = 32,
@@ -129,7 +131,7 @@ def __init__(
prepare_data_per_node: bool = True,
cluster_environment: ClusterEnvironment = None,
amp_backend: str = 'native',
- amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0
+ amp_level: str = 'O2',
overfit_pct: float = None,  # backward compatible, todo: remove in v1.0.0
+ row_log_interval: Optional[int] = None,  # backward compatible, todo: remove in v0.11.0
+ log_save_interval: Optional[int] = None,  # backward compatible, todo: remove in v0.11.0
):
r"""
@@ -143,10 +145,10 @@ def __init__(
amp_level: The optimization level to use (O1, O2, etc...).
- auto_lr_find: If set to True, will `initially` run a learning rate finder,
- trying to optimize initial learning for faster convergence. Sets learning
- rate in self.lr or self.learning_rate in the LightningModule.
- To use a different key, set a string instead of True with the key name.
+ auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
+ trying to optimize the initial learning rate for faster convergence. trainer.tune() will
+ set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
+ To use a different key, set a string instead of True with the key name.
auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
@@ -178,10 +180,14 @@ def __init__(
distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`).
- Deprecated since v0.10.0 and will be removed in v1.0.
+ .. warning:: .. deprecated:: 0.10.0
+
+ Will be removed in v1.0.
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
+ flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
+
gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node
gradient_clip_val: 0 means don't clip.
@@ -196,7 +202,12 @@ def __init__(
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
- log_save_interval: Writes logs to disk this often
+ log_every_n_steps: How often to log within steps (defaults to every 50 steps).
+
+ log_save_interval: How often to flush logs to disk.
+ .. warning:: .. deprecated:: 0.10.0
+
+ Use `flush_logs_every_n_steps` instead. Will be removed in v0.11.0.
prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
@@ -235,7 +246,10 @@ def __init__(
resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
This can be a URL.
- row_log_interval: How often to add logging rows (does not write to disk)
+ row_log_interval: How often to log within steps.
+ .. warning:: .. deprecated:: 0.10.0
+
+ Use `log_every_n_steps` instead. Will be removed in v0.11.0.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
@@ -262,6 +276,19 @@ def __init__(
"""
super().__init__()
+ # deprecation warnings
+ if row_log_interval is not None:
+ warnings.warn("Argument `row_log_interval` is deprecated in v0.10, use `log_every_n_steps` instead."
+ " It will be removed in v0.11.0.", DeprecationWarning)
+ log_every_n_steps = row_log_interval
+
+ if log_save_interval is not None:
+ warnings.warn(
+ "Argument `log_save_interval` is deprecated in v0.10, use `flush_logs_every_n_steps` instead."
+ " It will be removed in v0.11.0.", DeprecationWarning
+ )
+ flush_logs_every_n_steps = log_save_interval
+
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
@@ -299,7 +326,7 @@ def __init__(
process_position,
default_root_dir,
weights_save_path,
- resume_from_checkpoint
+ resume_from_checkpoint,
)
# hook
@@ -310,18 +337,12 @@ def __init__(
# init data flags
self.data_connector.on_trainer_init(
- check_val_every_n_epoch,
- reload_dataloaders_every_epoch,
- prepare_data_per_node
+ check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node
)
# init training tricks
self.training_tricks_connector.on_trainer_init(
- gradient_clip_val,
- track_grad_norm,
- accumulate_grad_batches,
- truncated_bptt_steps,
- terminate_on_nan
+ gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
)
# init accelerator related flags
@@ -351,7 +372,7 @@ def __init__(
self.profile_connector.on_trainer_init(profiler)
# init logger flags
- self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)
+ self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps)
# init debugging flags
self.debugging_connector.on_init_start(
@@ -361,7 +382,7 @@ def __init__(
limit_test_batches,
val_check_interval,
overfit_batches,
- fast_dev_run
+ fast_dev_run,
)
# set precision
@@ -377,8 +398,21 @@ def tune(
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
- # TODO: temporary, need to decide if tune or separate object
+ r"""
+ Runs routines to tune hyperparameters before training.
+
+ Args:
+ datamodule: An instance of :class:`LightningDataModule`.
+
+ model: Model to tune.
+
+ train_dataloader: A PyTorch DataLoader with training samples. If the model has
+ a predefined train_dataloader method, this will be skipped.
+ val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
+ If the model has a predefined val_dataloaders method, this will be skipped.
+
+ """
# setup data, etc...
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
@@ -511,13 +545,15 @@ def train(self):
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
if self.should_stop:
- if (met_min_epochs and met_min_steps):
+ if met_min_epochs and met_min_steps:
self.train_loop.on_train_end()
return
else:
- log.info('Trainer was signaled to stop but required minimum epochs'
- f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
- ' not been met. Training will continue...')
+ log.info(
+ 'Trainer was signaled to stop but required minimum epochs'
+ f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
+ ' not been met. Training will continue...'
+ )
# hook
self.train_loop.on_train_end()
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 4cd923b45242a..430d4124d0cf4 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -110,6 +110,7 @@ def setup_training(self, model: LightningModule):
if self.trainer.data_parallel:
ref_model = model.module
+ # set the ranks and devices
self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank
self.trainer.accelerator_backend.dist.device = ref_model.device
@@ -125,7 +126,7 @@ def setup_training(self, model: LightningModule):
# log hyper-parameters
if self.trainer.logger is not None:
- # save exp to get started
+ # save exp to get started (this is where the first experiment logs are written)
self.trainer.logger.log_hyperparams(ref_model.hparams)
self.trainer.logger.log_graph(ref_model)
self.trainer.logger.save()
@@ -460,7 +461,7 @@ def on_before_backward(self, batch_idx, optimizer):
def _track_gradient_norm(self):
grad_norm_dict = {}
- if (self.trainer.global_step + 1) % self.trainer.row_log_interval == 0:
+ if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0:
if float(self.trainer.track_grad_norm) > 0:
model = self.trainer.get_model()
grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
@@ -805,7 +806,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
def save_loggers_on_train_batch_end(self):
# when loggers should save to disk
should_save_log = (
- (self.trainer.global_step + 1) % self.trainer.log_save_interval == 0 or self.trainer.should_stop
+ (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 or self.trainer.should_stop
)
if should_save_log or self.trainer.fast_dev_run:
if self.trainer.is_global_zero and self.trainer.logger is not None:
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 8b2e05c66b753..ad3a9eb2e55c9 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -123,6 +123,7 @@ def __scale_batch_dump_params(trainer):
# Prevent going into infinite loop
trainer.__dumped_params = {
'auto_lr_find': trainer.auto_lr_find,
+ 'current_epoch': trainer.current_epoch,
'max_steps': trainer.max_steps,
'weights_summary': trainer.weights_summary,
'logger': trainer.logger,
@@ -138,6 +139,7 @@ def __scale_batch_dump_params(trainer):
def __scale_batch_reset_params(trainer, model, steps_per_trial):
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
+ trainer.current_epoch = 0
trainer.max_steps = steps_per_trial # take few steps
trainer.weights_summary = None # not needed before full run
trainer.logger = DummyLogger()
@@ -151,6 +153,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
def __scale_batch_restore_params(trainer):
trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
+ trainer.current_epoch = trainer.__dumped_params['current_epoch']
trainer.max_steps = trainer.__dumped_params['max_steps']
trainer.weights_summary = trainer.__dumped_params['weights_summary']
trainer.logger = trainer.__dumped_params['logger']
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index a3ba2550186a7..f1ca1609eb91a 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -13,10 +13,12 @@
# limitations under the License.
import importlib
import os
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, Callable
+from functools import wraps
import numpy as np
import torch
+from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@@ -165,13 +167,7 @@ def lr_find(
trainer.save_checkpoint(str(save_path))
# Configure optimizer and scheduler
- optimizers, _, _ = trainer.init_optimizers(model)
-
- if len(optimizers) != 1:
- raise MisconfigurationException(
- f'`model.configure_optimizers()` returned {len(optimizers)}, but'
- ' learning rate finder only works with single optimizer')
- model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
+ model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)
# Fit, lr & loss logged in callback
trainer.fit(model,
@@ -261,28 +257,47 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.results = {}
self._total_batch_idx = 0 # for debug purpose
- def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
- """ Construct a new `configure_optimizers()` method, that has a optimizer
- with initial lr set to lr_min and a scheduler that will either
- linearly or exponentially increase the lr to lr_max in num_training steps.
-
- Args:
- optimizer: instance of `torch.optim.Optimizer`
-
+ def _exchange_scheduler(self, configure_optimizers: Callable):
+ """ Decorate configure_optimizers methods such that it returns the users
+ originally specified optimizer together with a new scheduler that
+ that takes care of the learning rate search.
"""
- new_lrs = [self.lr_min] * len(optimizer.param_groups)
- for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
- param_group["lr"] = new_lr
- param_group["initial_lr"] = new_lr
-
- args = (optimizer, self.lr_max, self.num_training)
- scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
+ @wraps(configure_optimizers)
+ def func():
+ # Decide the structure of the output from configure_optimizers
+ # Same logic as method `init_optimizers` in trainer/optimizers.py
+ optim_conf = configure_optimizers()
+ if isinstance(optim_conf, Optimizer):
+ optimizers = [optim_conf]
+ elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
+ and isinstance(optim_conf[0], list):
+ optimizers, _ = optim_conf
+ elif isinstance(optim_conf, dict):
+ optimizers = [optim_conf["optimizer"]]
+ elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+ optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+ elif isinstance(optim_conf, (list, tuple)):
+ optimizers = list(optim_conf)  # a bare list/tuple already contains the optimizers
+
+ if len(optimizers) != 1:
+ raise MisconfigurationException(
+ f'`model.configure_optimizers()` returned {len(optimizers)}, but'
+ ' learning rate finder only works with single optimizer')
+
+ optimizer = optimizers[0]
+
+ new_lrs = [self.lr_min] * len(optimizer.param_groups)
+ for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
+ param_group["lr"] = new_lr
+ param_group["initial_lr"] = new_lr
+
+ args = (optimizer, self.lr_max, self.num_training)
+ scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
- def configure_optimizers():
return [optimizer], [{'scheduler': scheduler,
'interval': 'step'}]
- return configure_optimizers
+ return func
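For reference, a small sketch of the single-optimizer return shapes the branches above resolve (hypothetical opt/scheduler objects, not library code):

    import torch
    from torch.optim import Optimizer

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    # Shapes the wrapper resolves to a single optimizer:
    #   opt                      (bare Optimizer)
    #   ([opt], [scheduler])     (two-list form)
    #   {'optimizer': opt}       (dict form)
    #   [{'optimizer': opt}]     (list of dicts)
    # Anything resolving to more than one optimizer raises MisconfigurationException.
    assert isinstance(opt, Optimizer)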
def plot(self, suggest: bool = False, show: bool = False):
""" Plot results from lr_find run
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 8c55ffac92c6a..7c2bb65bd4bb2 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -37,6 +37,42 @@ def scale_batch_size(self,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs):
+ r"""
+ Will iteratively try to find the largest batch size for a given model
+ that does not give an out of memory (OOM) error.
+
+ Args:
+ model: Model to fit.
+
+ mode: string setting the search mode. Either `power` or `binsearch`.
+ If mode is `power`, we keep multiplying the batch size by 2 until
+ we get an OOM error. If mode is `binsearch`, we will initially
+ also keep multiplying by 2 and, after encountering an OOM error,
+ do a binary search between the last successful batch size and the
+ batch size that failed.
+
+ steps_per_trial: number of steps to run with a given batch size.
+ Ideally 1 should be enough to test if an OOM error occurs;
+ however, in practice a few are needed.
+
+ init_val: initial batch size to start the search with
+
+ max_trials: max number of increases in batch size before the
+ algorithm is terminated
+
+ batch_arg_name: name of the attribute that stores the batch size.
+ It is expected that the user has provided a model or datamodule that has a hyperparameter
+ with that name. We will look for this attribute name in the following places
+
+ - `model`
+ - `model.hparams`
+ - `model.datamodule`
+ - `trainer.datamodule` (the datamodule passed to the tune method)
+
+ **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
+ or datamodule.
+
+ """
return scale_batch_size(
self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
)
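A hypothetical usage sketch, assuming the `Tuner` is reachable as `trainer.tuner` and the model stores `batch_size` in its hparams:

    trainer = Trainer(max_epochs=1)
    new_size = trainer.tuner.scale_batch_size(model, mode='binsearch', init_val=8)
    model.hparams.batch_size = new_size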
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
new file mode 100644
index 0000000000000..470b4ac33412b
--- /dev/null
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -0,0 +1,74 @@
+import functools
+import importlib
+from multiprocessing import Process, Queue
+
+import torch
+
+TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
+if TORCHXLA_AVAILABLE:
+ import torch_xla.core.xla_model as xm
+else:
+ xm = None
+
+
+def inner_f(queue, func, **kwargs): # pragma: no cover
+ try:
+ queue.put(func(**kwargs))
+ except Exception as _e:
+ import traceback
+
+ traceback.print_exc()
+ queue.put(None)
+
+
+def pl_multi_process(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ queue = Queue()
+ proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
+ proc.start()
+ proc.join()
+ return queue.get()
+
+ return wrapper
+
+
+class XLADeviceUtils:
+ """Used to detect the type of XLA device"""
+
+ TPU_AVAILABLE = None
+
+ @staticmethod
+ def _fetch_xla_device_type(device: torch.device) -> str:
+ """
+ Returns the XLA device type.
+ Args:
+ device: (:class:`~torch.device`): A torch.device in XLA format, e.g. xla:0
+ Return:
+ A str with the device hardware type, e.g. TPU
+ """
+ if xm is not None:
+ return xm.xla_device_hw(device)
+
+ @staticmethod
+ def _is_device_tpu() -> bool:
+ """
+ Check if device is TPU
+ Return:
+ A boolean value indicating if the XLA device is a TPU device or not
+ """
+ if xm is not None:
+ device = xm.xla_device()
+ device_type = XLADeviceUtils._fetch_xla_device_type(device)
+ return device_type == "TPU"
+
+ @staticmethod
+ def tpu_device_exists() -> bool:
+ """
+ Public method to check if TPU is available
+ Return:
+ A boolean value indicating if a TPU device exists on the system
+ """
+ if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
+ XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
+ return XLADeviceUtils.TPU_AVAILABLE
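Usage sketch for the new utility: the subprocess check runs at most once and the result is cached on the class (`None`, which is falsy, when `torch_xla` is absent):

    from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

    TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
    if TPU_AVAILABLE:
        import torch_xla.core.xla_model as xm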
diff --git a/tests/base/boring_model.py b/tests/base/boring_model.py
index b57db5d33f6d4..a9cf3695cfab8 100644
--- a/tests/base/boring_model.py
+++ b/tests/base/boring_model.py
@@ -3,6 +3,32 @@
from torch.utils.data import Dataset
+class RandomDictDataset(Dataset):
+ def __init__(self, size, length):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ a = self.data[index]
+ b = a + 2
+ return {'a': a, 'b': b}
+
+ def __len__(self):
+ return self.len
+
+
+class RandomDictStringDataset(Dataset):
+ def __init__(self, size, length):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ return {"id": str(index), "x": self.data[index]}
+
+ def __len__(self):
+ return self.len
+
+
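These datasets exercise the dict branch of `unpack_batch_size`; with the default collate function, a sketch of what a batch looks like:

    from torch.utils.data import DataLoader

    loader = DataLoader(RandomDictDataset(size=32, length=64), batch_size=16)
    batch = next(iter(loader))   # {'a': Tensor[16, 32], 'b': Tensor[16, 32]}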
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py
index 9e5f558b43dcd..a3733bec835fa 100644
--- a/tests/base/model_optimizers.py
+++ b/tests/base/model_optimizers.py
@@ -24,6 +24,10 @@ def configure_optimizers__lbfgs(self):
optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate)
return optimizer
+ def configure_optimizers__adagrad(self):
+ optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
+ return optimizer
+
def configure_optimizers__multiple_optimizers(self):
"""
return whatever optimizers we want here.
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 8a1daaf695a2f..0e01fee7a8856 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -6,7 +6,7 @@
import pytest
import torch
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -35,7 +35,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
"""
-
+ seed_everything(42)
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(monitor="early_stop_on", save_top_k=1)
early_stop_callback = EarlyStoppingTestRestore()
@@ -60,7 +60,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(
default_root_dir=tmpdir,
- max_epochs=2,
+ max_epochs=1,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index 6a9f3aa1c92ad..80453bbb658b2 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -157,15 +157,23 @@ def name(self):
@pytest.mark.parametrize("logger_class", [
- TensorBoardLogger,
CometLogger,
MLFlowLogger,
NeptuneLogger,
+ TensorBoardLogger,
TestTubeLogger,
# The WandbLogger gets tested for pickling in its own test.
])
-@mock.patch('pytorch_lightning.loggers.neptune.neptune')
-def test_loggers_pickle(neptune, tmpdir, monkeypatch, logger_class):
+def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class):
+ """ Test that the logger objects can be pickled. This test only makes sense if the packages are installed. """
+ _patch_comet_atexit(monkeypatch)
+ try:
+ _test_loggers_pickle(tmpdir, monkeypatch, logger_class)
+ except (ImportError, ModuleNotFoundError):
+ pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")
+
+
+def _test_loggers_pickle(tmpdir, monkeypatch, logger_class):
"""Verify that pickling trainer with logger works."""
_patch_comet_atexit(monkeypatch)
diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index 0e1199e88d27a..0e7bbabaf4e9e 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -15,27 +15,24 @@ def _patch_comet_atexit(monkeypatch):
monkeypatch.setattr(atexit, "register", lambda _: None)
-def test_comet_logger_online():
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_logger_online(comet):
"""Test comet online with mocks."""
# Test api_key given
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')
_ = logger.experiment
- comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
+ comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
# Test both given
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')
_ = logger.experiment
- comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
-
- # Test neither given
- with pytest.raises(MisconfigurationException):
- CometLogger(workspace='dummy-test', project_name='general')
+ comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
# Test already exists
with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
@@ -61,52 +58,73 @@ def test_comet_logger_online():
api.assert_called_once_with('rest')
-def test_comet_logger_experiment_name():
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_logger_no_api_key_given(comet):
+ """ Test that CometLogger fails to initialize if both api key and save_dir are missing. """
+ with pytest.raises(MisconfigurationException):
+ comet.config.get_api_key.return_value = None
+ CometLogger(workspace='dummy-test', project_name='general')
+
+
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_logger_experiment_name(comet):
"""Test that Comet Logger experiment name works correctly."""
api_key = "key"
experiment_name = "My Name"
# Test api_key given
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)
assert logger._experiment is None
_ = logger.experiment
- comet.assert_called_once_with(api_key=api_key, project_name=None)
+ comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)
- comet().set_name.assert_called_once_with(experiment_name)
+ comet_experiment().set_name.assert_called_once_with(experiment_name)
-def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
+@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment')
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that the logger creates the folders and files in the right place. """
_patch_comet_atexit(monkeypatch)
+ comet.config.get_api_key.return_value = None
+ comet.generate_guid.return_value = "4321"
+
logger = CometLogger(project_name='test', save_dir=tmpdir)
assert not os.listdir(tmpdir)
assert logger.mode == 'offline'
assert logger.save_dir == tmpdir
+ assert logger.name == 'test'
+ assert logger.version == "4321"
_ = logger.experiment
- version = logger.version
- assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'}
+
+ comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test')
+
+ # mock return values of experiment
+ logger.experiment.id = '1'
+ logger.experiment.project_name = 'test'
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)
- assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
+ assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
-def test_comet_name_default():
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_name_default(comet):
""" Test that CometLogger.name don't create an Experiment and returns a default value. """
api_key = "key"
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)
assert logger._experiment is None
@@ -116,13 +134,14 @@ def test_comet_name_default():
assert logger._experiment is None
-def test_comet_name_project_name():
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_name_project_name(comet):
""" Test that CometLogger.name does not create an Experiment and returns project name if passed. """
api_key = "key"
project_name = "My Project Name"
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)
assert logger._experiment is None
@@ -132,13 +151,15 @@ def test_comet_name_project_name():
assert logger._experiment is None
-def test_comet_version_without_experiment():
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_version_without_experiment(comet):
""" Test that CometLogger.version does not create an Experiment. """
api_key = "key"
experiment_name = "My Name"
+ comet.generate_guid.return_value = "1234"
- with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
+ with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
assert logger._experiment is None
@@ -154,15 +175,16 @@ def test_comet_version_without_experiment():
logger.reset_experiment()
- second_version = logger.version
+ second_version = logger.version == "1234"
assert second_version is not None
assert second_version != first_version
-def test_comet_epoch_logging(tmpdir, monkeypatch):
+@patch("pytorch_lightning.loggers.comet.CometExperiment")
+@patch('pytorch_lightning.loggers.comet.comet_ml')
+def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """
_patch_comet_atexit(monkeypatch)
- with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics:
- logger = CometLogger(project_name="test", save_dir=tmpdir)
- logger.log_metrics({"test": 1, "epoch": 1}, step=123)
- log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
+ logger = CometLogger(project_name="test", save_dir=tmpdir)
+ logger.log_metrics({"test": 1, "epoch": 1}, step=123)
+ logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index e5b871e4ec7be..3676cd05fe027 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -1,6 +1,8 @@
import os
from unittest import mock
+from unittest.mock import MagicMock
+
from mlflow.tracking import MlflowClient
from pytorch_lightning import Trainer
@@ -8,15 +10,57 @@
from tests.base import EvalModelTemplate
-def test_mlflow_logger_exists(tmpdir):
- """ Test launching two independent loggers. """
+@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
+@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
+def test_mlflow_logger_exists(client, mlflow, tmpdir):
+ """ Test launching three independent loggers with either same or different experiment name. """
+
+ run1 = MagicMock()
+ run1.info.run_id = "run-id-1"
+
+ run2 = MagicMock()
+ run2.info.run_id = "run-id-2"
+
+ run3 = MagicMock()
+ run3.info.run_id = "run-id-3"
+
+ # simulate non-existing experiment creation
+ client.return_value.get_experiment_by_name = MagicMock(return_value=None)
+ client.return_value.create_experiment = MagicMock(return_value="exp-id-1") # experiment_id
+ client.return_value.create_run = MagicMock(return_value=run1)
+
logger = MLFlowLogger('test', save_dir=tmpdir)
+ assert logger._experiment_id is None
+ assert logger._run_id is None
+ _ = logger.experiment
+ assert logger.experiment_id == "exp-id-1"
+ assert logger.run_id == "run-id-1"
+ logger.experiment.create_experiment.assert_called_once()
+ client.reset_mock(return_value=True)
+
+ # simulate existing experiment returns experiment id
+ exp1 = MagicMock()
+ exp1.experiment_id = "exp-id-1"
+ client.return_value.get_experiment_by_name = MagicMock(return_value=exp1)
+ client.return_value.create_run = MagicMock(return_value=run2)
+
# same name leads to same experiment id, but different runs get recorded
logger2 = MLFlowLogger('test', save_dir=tmpdir)
- assert logger.experiment_id == logger2.experiment_id
- assert logger.run_id != logger2.run_id
+ assert logger2.experiment_id == logger.experiment_id
+ assert logger2.run_id == "run-id-2"
+ assert logger2.experiment.create_experiment.call_count == 0
+ logger2.experiment.create_run.assert_called_once()
+ client.reset_mock(return_value=True)
+
+ # simulate a 3rd experiment with new name
+ client.return_value.get_experiment_by_name = MagicMock(return_value=None)
+ client.return_value.create_experiment = MagicMock(return_value="exp-id-3")
+ client.return_value.create_run = MagicMock(return_value=run3)
+
+ # logger with new experiment name causes new experiment id and new run id to be created
logger3 = MLFlowLogger('new', save_dir=tmpdir)
- assert logger3.experiment_id != logger.experiment_id
+ assert logger3.experiment_id == "exp-id-3" != logger.experiment_id
+ assert logger3.run_id == "run-id-3"
def test_mlflow_logger_dirs_creation(tmpdir):
diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index 0e8dece3e070a..61fb3ae7eb2e2 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -59,7 +59,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
default_root_dir=tmpdir,
max_epochs=3,
track_grad_norm=norm_type,
- row_log_interval=1, # request grad_norms every batch
+ log_every_n_steps=1, # request grad_norms every batch
)
result = trainer.fit(model)
@@ -76,20 +76,20 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
assert np.allclose(log, mod, rtol=rtol)
-@pytest.mark.parametrize("row_log_interval", [1, 2, 3])
-def test_grad_tracking_interval(tmpdir, row_log_interval):
+@pytest.mark.parametrize("log_every_n_steps", [1, 2, 3])
+def test_grad_tracking_interval(tmpdir, log_every_n_steps):
""" Test that gradient norms get tracked in the right interval and that everytime the same keys get logged. """
trainer = Trainer(
default_root_dir=tmpdir,
track_grad_norm=2,
- row_log_interval=row_log_interval,
+ log_every_n_steps=log_every_n_steps,
max_steps=10,
)
with patch.object(trainer.logger, "log_metrics") as mocked:
model = EvalModelTemplate()
trainer.fit(model)
- expected = trainer.global_step // row_log_interval
+ expected = trainer.global_step // log_every_n_steps
grad_norm_dicts = []
for _, kwargs in mocked.call_args_list:
metrics = kwargs.get("metrics", {})
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index cddc3db78ac4e..4f328b5d65d94 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -1,26 +1,27 @@
import os
+from multiprocessing import Process, Queue
import pytest
from torch.utils.data import DataLoader
import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.accelerators.base_backend import BackendType
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test
-try:
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
+if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()
-except ImportError:
- TPU_AVAILABLE = False
-else:
- TPU_AVAILABLE = True
_LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8))
@@ -216,7 +217,6 @@ def test_tpu_misconfiguration():
Trainer(tpu_cores=[1, 8])
-# @patch('pytorch_lightning.trainer.trainer.XLA_AVAILABLE', False)
@pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""
@@ -263,7 +263,7 @@ def test_result_obj_on_tpu(tmpdir):
default_root_dir=tmpdir,
max_epochs=epochs,
callbacks=[EarlyStopping()],
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=batches,
weights_summary=None,
tpu_cores=8
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index 75aae09fe07e8..3e8639c91233f 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -18,13 +18,25 @@ def _soft_unimport_module(str_module):
def test_tbd_remove_in_v0_11_0_trainer():
with pytest.deprecated_call(match='will be removed in v0.11.0'):
- lr_logger = LearningRateLogger()
+ LearningRateLogger()
+
+ with pytest.deprecated_call(match='will be removed in v0.11.0'):
+ trainer = Trainer(row_log_interval=8)
+ assert trainer.log_every_n_steps == 8
+ with pytest.deprecated_call(match='will be removed in v0.11.0'):
+ assert trainer.row_log_interval == 8
+
+ with pytest.deprecated_call(match='will be removed in v0.11.0'):
+ trainer = Trainer(log_save_interval=9)
+ assert trainer.flush_logs_every_n_steps == 9
+ with pytest.deprecated_call(match='will be removed in v0.11.0'):
+ assert trainer.log_save_interval == 9
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_tbd_remove_in_v0_11_0_trainer_gpu():
with pytest.deprecated_call(match='will be removed in v0.11.0'):
- gpu_usage = GpuUsageLogger()
+ GpuUsageLogger()
class ModelVer0_6(EvalModelTemplate):
diff --git a/tests/trainer/data_flow/test_eval_loop_flow_1_0.py b/tests/trainer/data_flow/test_eval_loop_flow_1_0.py
index 4feffca178b81..7c64c3aae2e5c 100644
--- a/tests/trainer/data_flow/test_eval_loop_flow_1_0.py
+++ b/tests/trainer/data_flow/test_eval_loop_flow_1_0.py
@@ -41,7 +41,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -90,7 +90,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -147,7 +147,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
@@ -211,7 +211,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
diff --git a/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py
index 71ff0a21c5bfe..0767684169adc 100644
--- a/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py
+++ b/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py
@@ -31,7 +31,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -73,7 +73,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -121,7 +121,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -175,7 +175,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
diff --git a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
index ff49fb68d343b..3823dec33fb21 100644
--- a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
+++ b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
@@ -33,7 +33,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -75,7 +75,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -123,7 +123,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -177,7 +177,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py
index 40236deef7a1e..114dd0a9497b3 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py
@@ -38,7 +38,7 @@ def test_training_step_result_log_step_only(tmpdir):
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
- row_log_interval=1,
+ log_every_n_steps=1,
max_epochs=1,
weights_summary=None,
)
@@ -113,7 +113,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
- row_log_interval=1,
+ log_every_n_steps=1,
max_epochs=epochs,
weights_summary=None,
)
@@ -190,7 +190,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
- row_log_interval=1,
+ log_every_n_steps=1,
max_epochs=epochs,
weights_summary=None,
)
@@ -322,7 +322,7 @@ def test_training_step_epoch_end_result(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -404,7 +404,7 @@ def test_no_auto_callbacks_with_train_loop_only(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -422,7 +422,7 @@ def test_no_auto_callbacks_with_train_loop_only(tmpdir):
default_root_dir=tmpdir,
early_stop_callback=True,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -447,7 +447,7 @@ def test_no_callbacks_with_train_loop_only(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -476,7 +476,7 @@ def test_use_callbacks_with_train_loop_only(tmpdir):
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -532,7 +532,7 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
gpus=[0, 1],
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=batches,
weights_summary=None,
)
@@ -573,7 +573,7 @@ def test_loop_steps_only_dp(tmpdir):
gpus=[0, 1],
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=batches,
weights_summary=None,
)
@@ -613,7 +613,7 @@ def test_result_monitor_warnings(tmpdir):
default_root_dir=tmpdir,
max_epochs=2,
early_stop_callback=True,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=2,
weights_summary=None,
checkpoint_callback=ModelCheckpoint(monitor='not_checkpoint_on')
@@ -626,7 +626,7 @@ def test_result_monitor_warnings(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=2,
weights_summary=None,
early_stop_callback=EarlyStopping(monitor='not_val_loss')
@@ -653,7 +653,7 @@ def test_eval_loop_return_none(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=2,
weights_summary=None,
)
diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py
index a43b50c442dac..f8be3b9ea67b0 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py
@@ -40,7 +40,7 @@ def test_val_step_result_callbacks(tmpdir):
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -88,7 +88,7 @@ def test_val_step_using_train_callbacks(tmpdir):
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -135,7 +135,7 @@ def test_val_step_only_epoch_metrics(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
weights_summary=None,
)
@@ -194,7 +194,7 @@ def test_val_step_only_step_metrics(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
limit_val_batches=batches,
weights_summary=None,
@@ -240,7 +240,7 @@ def test_val_step_epoch_step_metrics(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
limit_val_batches=batches,
weights_summary=None,
@@ -327,7 +327,7 @@ def test_val_step_epoch_end_result(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
- row_log_interval=1,
+ log_every_n_steps=1,
limit_train_batches=batches,
limit_val_batches=batches,
weights_summary=None,
@@ -390,7 +390,7 @@ def test_val_step_full_loop_result_dp(tmpdir):
gpus=[0, 1],
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=batches,
weights_summary=None,
)
@@ -444,7 +444,7 @@ def test_full_loop_result_cpu(tmpdir):
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
- row_log_interval=2,
+ log_every_n_steps=2,
limit_train_batches=batches,
weights_summary=None,
)
diff --git a/tests/trainer/logging/test_distributed_logging.py b/tests/trainer/logging/test_distributed_logging.py
new file mode 100644
index 0000000000000..1317ea0e5ddd6
--- /dev/null
+++ b/tests/trainer/logging/test_distributed_logging.py
@@ -0,0 +1,59 @@
+import platform
+from distutils.version import LooseVersion
+from unittest import mock
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from tests.base import BoringModel
+
+
+class TestModel(BoringModel):
+
+ def on_pretrain_routine_end(self) -> None:
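+        # patch the logger facade and trigger one logging call: only global
+        # rank zero should reach the actual logger, every other rank is a no-op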
+ with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m:
+ self.trainer.logger_connector.log_metrics({'a': 2}, {})
+ logged_times = m.call_count
+ expected = 1 if self.global_rank == 0 else 0
+            assert logged_times == expected, 'the logger should only be called from global rank zero'
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+ reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif((platform.system() == "Darwin" and
+ LooseVersion(torch.__version__) < LooseVersion("1.3.0")),
+ reason="Distributed training is not supported on MacOS before Torch 1.3.0")
+def test_global_zero_only_logging_ddp_cpu(tmpdir):
+ """
+    Makes sure logging only happens from global rank zero
+ """
+ model = TestModel()
+ model.training_epoch_end = None
+ trainer = Trainer(
+ distributed_backend='ddp_cpu',
+ num_processes=2,
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_global_zero_only_logging_ddp_spawn(tmpdir):
+ """
+    Makes sure logging only happens from global rank zero
+ """
+ model = TestModel()
+ model.training_epoch_end = None
+ trainer = Trainer(
+ distributed_backend='ddp_spawn',
+ gpus=2,
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model)
diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py
index 831e59965e1a4..7ed857d8565bb 100644
--- a/tests/trainer/logging/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -42,7 +42,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -111,7 +111,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -168,7 +168,7 @@ def validation_epoch_end(self, outputs):
limit_train_batches=batches,
limit_val_batches=batches,
max_epochs=max_epochs,
- row_log_interval=log_interval,
+ log_every_n_steps=log_interval,
weights_summary=None,
)
trainer.fit(model)
diff --git a/tests/trainer/logging/test_train_loop_logging_1_0.py b/tests/trainer/logging/test_train_loop_logging_1_0.py
index 848e6e74a59f2..c4cc8b0ed638f 100644
--- a/tests/trainer/logging/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging/test_train_loop_logging_1_0.py
@@ -1,7 +1,7 @@
"""
Tests to ensure that the training loop works with a dict (1.0)
"""
-from tests.base.boring_model import BoringModel
+from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
import os
import torch
import pytest
@@ -64,7 +64,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -137,7 +137,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
@@ -149,16 +149,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
# make sure all the metrics are available for callbacks
logged_metrics = set(trainer.logged_metrics.keys())
- expected_logged_metrics = {
- 'epoch',
- 'a',
- 'step_a',
- 'epoch_a',
- 'b',
- 'b1',
- 'a1',
- 'a2'
- }
+ expected_logged_metrics = {'epoch', 'a', 'step_a', 'epoch_a', 'b', 'b1', 'a1', 'a2'}
assert logged_metrics == expected_logged_metrics
pbar_metrics = set(trainer.progress_bar_metrics.keys())
@@ -208,7 +199,7 @@ def training_epoch_end(self, outputs):
limit_train_batches=batches,
limit_val_batches=batches,
max_epochs=max_epochs,
- row_log_interval=log_interval,
+ log_every_n_steps=log_interval,
weights_summary=None,
)
trainer.fit(model)
@@ -352,3 +343,79 @@ def train_dataloader(self):
generated = set(trainer.logged_metrics.keys())
expected = {'a', 'step_a', 'epoch_a', 'epoch'}
assert generated == expected
+
+
+def test_different_batch_types_for_sizing(tmpdir):
+
+ class TestModel(BoringModel):
+ def training_step(self, batch, batch_idx):
+ assert isinstance(batch, dict)
+ a = batch['a']
+ acc = self.step(a)
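+            # a dict can be logged under a single key; on_epoch=True also
+            # produces the 'epoch_a' aggregate checked in `expected` below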
+ self.log('a', {'d1': 2, 'd2': torch.tensor(1)}, on_step=True, on_epoch=True)
+ return acc
+
+ def validation_step(self, batch, batch_idx):
+ assert isinstance(batch, dict)
+ a = batch['a']
+ output = self.layer(a)
+ loss = self.loss(batch, output)
+ self.log('n', {'d3': 2, 'd4': torch.tensor(1)}, on_step=True, on_epoch=True)
+ return {"x": loss}
+
+ def train_dataloader(self):
+ return torch.utils.data.DataLoader(RandomDictDataset(32, 64), batch_size=32)
+
+ def val_dataloader(self):
+ return torch.utils.data.DataLoader(RandomDictDataset(32, 64), batch_size=32)
+
+ model = TestModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=2,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model)
+
+ generated = set(trainer.logger_connector.logged_metrics)
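+    # step-level metrics from validation_step are namespaced per epoch
+    # ('step_n/epoch_0'), while on_epoch=True yields the 'epoch_n' aggregate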
+ expected = {
+ 'epoch_a', 'a',
+ 'n', 'step_n/epoch_0', 'epoch_n',
+ 'epoch'
+ }
+
+ assert generated == expected
+
+
+def test_validation_step_with_string_data_logging(tmpdir):
+ class TestModel(BoringModel):
+ def training_step(self, batch, batch_idx):
+ output = self.layer(batch["x"])
+ loss = self.loss(batch, output)
+ return {"loss": loss}
+
+ def validation_step(self, batch, batch_idx):
+ output = self.layer(batch["x"])
+ loss = self.loss(batch, output)
+ self.log("x", loss)
+ return {"x": loss}
+
+ # fake data
+ train_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))
+ val_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))
+
+ # model
+ model = TestModel()
+ trainer = Trainer(
+        default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model, train_data, val_data)
diff --git a/tests/trainer/test_correct_freq_accumulation.py b/tests/trainer/test_correct_freq_accumulation.py
index 9403bf14e9a8e..18561fe17c051 100644
--- a/tests/trainer/test_correct_freq_accumulation.py
+++ b/tests/trainer/test_correct_freq_accumulation.py
@@ -28,7 +28,7 @@ def test_training_step_scalar(tmpdir):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
- row_log_interval=1,
+ log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index 67c673df1318d..67eb480a71c61 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -131,11 +131,14 @@ def test_trainer_arg_str(tmpdir, use_hparams):
'Learning rate was not altered after running learning rate finder'
-def test_call_to_trainer_method(tmpdir):
+@pytest.mark.parametrize('optimizer', ['Adam', 'Adagrad'])
+def test_call_to_trainer_method(tmpdir, optimizer):
""" Test that directly calling the trainer method works """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
+    if optimizer == 'Adagrad':
+        model.configure_optimizers = model.configure_optimizers__adagrad
before_lr = hparams.get('learning_rate')
# logger file to get meta
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
index 33f0ca8cede5d..34112e7cbd2ad 100644
--- a/tests/trainer/test_optimizers.py
+++ b/tests/trainer/test_optimizers.py
@@ -1,9 +1,10 @@
import pytest
import torch
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base.boring_model import BoringModel
def test_optimizer_with_scheduling(tmpdir):
@@ -298,3 +299,50 @@ def test_init_optimizers_during_testing(tmpdir):
assert len(trainer.lr_schedulers) == 0
assert len(trainer.optimizers) == 0
assert len(trainer.optimizer_frequencies) == 0
+
+
+def test_multiple_optimizers_callbacks(tmpdir):
+ """
+ Tests that multiple optimizers can be used with callbacks
+ """
+ class CB(Callback):
+
+ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+ pass
+
+ def on_train_epoch_start(self, trainer, pl_module):
+ pass
+
+ class TestModel(BoringModel):
+ def __init__(self):
+ super().__init__()
+ self.layer_1 = torch.nn.Linear(32, 2)
+ self.layer_2 = torch.nn.Linear(32, 2)
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
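+            # with two optimizers returned from configure_optimizers, Lightning
+            # invokes training_step once per optimizer and passes optimizer_idx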
+ if optimizer_idx == 0:
+ a = batch[0]
+ acc = self.layer_1(a)
+ else:
+ a = batch[0]
+ acc = self.layer_2(a)
+
+ acc = self.loss(acc, acc)
+ return acc
+
+ def configure_optimizers(self):
+ a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2)
+ b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2)
+ return a, b
+
+ model = TestModel()
+ model.training_epoch_end = None
+ trainer = Trainer(
+ callbacks=[CB()],
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=2,
+ max_epochs=1,
+ weights_summary=None,
+ )
+ trainer.fit(model)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 6ec369a002d20..9ca9deb7f5d32 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1186,12 +1186,12 @@ def setup(self, stage):
pytest.param(3, 10, 5),
])
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
-def test_row_log_interval(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
+def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
- row_log_interval=log_interval,
- log_save_interval=log_interval,
+ log_every_n_steps=log_interval,
+ flush_logs_every_n_steps=log_interval,
limit_train_batches=train_batches,
limit_val_batches=0,
max_steps=max_steps,
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index a9297576c6f14..3ee47522d9258 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -182,7 +182,8 @@ def test_trainer_reset_correctly(tmpdir):
'callbacks',
'checkpoint_callback',
'early_stop_callback',
- 'limit_train_batches']
+ 'limit_train_batches',
+ 'current_epoch']
attributes_before = {}
for ca in changed_attributes:
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
new file mode 100644
index 0000000000000..f90fa750666bc
--- /dev/null
+++ b/tests/utilities/test_xla_device_utils.py
@@ -0,0 +1,31 @@
+import pytest
+
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
+from tests.base.develop_utils import pl_multi_process_test
+
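+# torch_xla is optional; probe for it so the skipif marks below can gate the
+# TPU tests on whether it is installed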
+try:
+ import torch_xla.core.xla_model as xm
+ XLA_AVAILABLE = True
+except ImportError:
+ XLA_AVAILABLE = False
+
+
+@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
+def test_tpu_device_absence():
+ """Check tpu_device_exists returns None when torch_xla is not available"""
+ assert xdu.tpu_device_exists() is None
+
+
+@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
+def test_tpu_device_presence():
+ """Check tpu_device_exists returns True when TPU is available"""
+ assert xdu.tpu_device_exists() is True
+
+
+@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
+@pl_multi_process_test
+def test_xla_device_is_a_tpu():
+ """Check that the XLA device is a TPU"""
+ device = xm.xla_device()
+ device_type = xm.xla_device_hw(device)
+    assert device_type == "TPU"