Skip to content

Commit

Permalink
Weekly Patch Release v.1.2.7 [full merge, no squash] (#6850)
Browse files Browse the repository at this point in the history
* update readme by v1.2.x (#6728)

* [bugfix] Add support for omegaconf and tpu (#6741)

* fix_hydra

* update changelog

Co-authored-by: Your Name <[email protected]>

* [docs] Update Bolts link (#6743)

* Update Bolts link

* Update Bolts link

* formt

Co-authored-by: Jirka Borovec <[email protected]>

* Update logic for checking TPUs availability (#6767)

* Update logic for checking TPUs availability

* fix flake8

* add fix

* resolve bug (#6781)

* Fix validation progress counter with check_val_every_n_epoch > 1 (#5952)

Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

* Remove extinct parameters from lightning_module.rst (#6801)

Fixes  #6800

* Update TPU docs for installation (#6794)

* fix boolean check on iterable dataset when len not defined (#6828)

* fix iterable dataset len check

* update predict and validate

* add validate to test

* add changelog

* add predict

* Sanitize `None` params during pruning (#6836)

* sanitize none params during pruning

* amend

* Fix `unfreeze_and_add_param_group` expects `modules` rather than `module` (#6822)

* Enforce an epoch scheduler interval when using SWA (#6588)

Co-authored-by: Carlos Mocholi <[email protected]>

* Fix DPP + SyncBN (#6838)

* Fix DPP + SyncBN

Ensure that model is already on correct GPU before applying SyncBN conversion

* Fix order of SyncBN for ddp_spawn

* [Fix] TPU Training Type Plugin (#6816)

* Fix support for symlink save_dir in TensorBoardLogger (#6730)

* Add test for symlink support and initial fix

* Respond to comment and add docstring

* Update CHANGELOG.md

* Simplify

* Update pytorch_lightning/utilities/cloud_io.py

Co-authored-by: Carlos Mocholí <[email protected]>

* Make `LightningLocalFileSystem` protected

Co-authored-by: Carlos Mocholí <[email protected]>

* Fixed missing arguments in `lr_find` call (#6784)

There seem to be 3 arguments missing in the `lr_find` call in the tunining.py file.

* Update Changelog & version

* Fix TPU tests for checkpoint

Skip advanced profiler for torch > 1.8

Skip pytorch profiler for torch > 1.8

Fix save checkpoint logic for TPUs

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Yuan-Hang Zhang <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Elizaveta Logacheva <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Karthik Prasad <[email protected]>
Co-authored-by: Sadiq Jaffer <[email protected]>
Co-authored-by: Michael Baumgartner <[email protected]>
Co-authored-by: Eugene Khvedchenya <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Tharindu Hasthika <[email protected]>
  • Loading branch information
16 people authored Apr 7, 2021
1 parent e7abd8e commit b2c7345
Show file tree
Hide file tree
Showing 34 changed files with 299 additions and 277 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.2.7] - 2021-04-06

### Fixed

- Fixed resolve a bug with omegaconf and xm.save ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
- Fixed an issue with IterableDataset when __len__ is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
- Sanitize None params during pruning ([#6836](https://github.com/PyTorchLightning/pytorch-lightning/pull/6836))
- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))


## [1.2.6] - 2021-03-30

### Changed
Expand Down
51 changes: 19 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,6 @@ Lightning is rigurously tested across multiple GPUs, TPUs CPUs and against major
</center>
</details>

<details>
<summary>Bleeding edge build status (1.2)</summary>

<center>

![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20base%20testing/badge.svg?branch=release%2F1.2-dev&event=push)
![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=release%2F1.2-dev&event=push)
![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=release%2F1.2-dev&event=push)
![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=release%2F1.2-dev&event=push)
![Docs check](https://github.com/PyTorchLightning/pytorch-lightning/workflows/Docs%20check/badge.svg?branch=release%2F1.2-dev&event=push)
</center>
</details>

---

## How To Use
Expand Down Expand Up @@ -132,22 +119,22 @@ pip install pytorch-lightning
conda install pytorch-lightning -c conda-forge
```

#### Install stable - future 1.1.x
#### Install stable 1.2.x

the actual status of 1.1 [stable] is following:
the actual status of 1.2 [stable] is following:

![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20base%20testing/badge.svg?branch=release%2F1.1.x&event=push)
![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=release%2F1.1.x&event=push)
![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=release%2F1.1.x&event=push)
![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=release%2F1.1.x&event=push)
![Docs check](https://github.com/PyTorchLightning/pytorch-lightning/workflows/Docs%20check/badge.svg?branch=release%2F1.1.x&event=push)
![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20base%20testing/badge.svg?branch=release%2F1.2.x&event=push)
![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=release%2F1.2.x&event=push)
![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=release%2F1.2.x&event=push)
![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=release%2F1.2.x&event=push)
![Docs check](https://github.com/PyTorchLightning/pytorch-lightning/workflows/Docs%20check/badge.svg?branch=release%2F1.2.x&event=push)

Install future release from the source
```bash
pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.1.x --upgrade
pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.2.x --upgrade
```

#### Install bleeding-edge - future 1.2
#### Install bleeding-edge - future 1.3

Install nightly from the source (no guarantees)
```bash
Expand Down Expand Up @@ -356,27 +343,27 @@ class LitAutoEncoder(pl.LightningModule):
- [MNIST on TPUs](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb)

###### Contrastive Learning
- [BYOL](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol)
- [CPC v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2)
- [Moco v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2)
- [SIMCLR](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)
- [BYOL](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol)
- [CPC v2](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2)
- [Moco v2](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2)
- [SIMCLR](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)

###### NLP
- [BERT](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb)
- [GPT-2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
- [GPT-2](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)


###### Reinforcement Learning
- [DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dqn-models)
- [Dueling-DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dueling-dqn)
- [Reinforce](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#reinforce)
- [DQN](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dqn-models)
- [Dueling-DQN](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dueling-dqn)
- [Reinforce](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#reinforce)

###### Vision
- [GAN](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb)

###### Classic ML
- [Logistic Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#logistic-regression)
- [Linear Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#linear-regression)
- [Logistic Regression](https://lightning-bolts.readthedocs.io/en/latest/classic_ml.html#logistic-regression)
- [Linear Regression](https://lightning-bolts.readthedocs.io/en/latest/classic_ml.html#linear-regression)

---

Expand Down
3 changes: 1 addition & 2 deletions docs/source/advanced/tpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ To get a TPU on colab, follow these steps:

.. code-block::
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
5. Once the above is done, install PyTorch Lightning (v 0.7.0+).

Expand Down
24 changes: 0 additions & 24 deletions docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -907,30 +907,6 @@ use_amp
~~~~~~~
True if using Automatic Mixed Precision (AMP)

------------

use_ddp
~~~~~~~
True if using ddp

------------

use_ddp2
~~~~~~~~
True if using ddp2

------------

use_dp
~~~~~~
True if using dp

------------

use_tpu
~~~~~~~
True if using TPUs

--------------

automatic_optimization
Expand Down
4 changes: 2 additions & 2 deletions docs/source/ecosystem/bolts.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
Bolts
=====
`PyTorch Lightning Bolts <https://pytorch-lightning-bolts.readthedocs.io/en/latest/>`_, is our official collection
`PyTorch Lightning Bolts <https://lightning-bolts.readthedocs.io/en/latest/>`_, is our official collection
of prebuilt models across many research domains.

.. code-block:: bash
pip install pytorch-lightning-bolts
pip install lightning-bolts
In bolts we have:

Expand Down
10 changes: 5 additions & 5 deletions docs/source/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ Examples
--------
You can do pretty much anything with callbacks.

- `Add a MLP to fine-tune self-supervised networks <https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_callbacks.html#sslonlineevaluator>`_.
- `Find how to modify an image input to trick the classification result <https://pytorch-lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#confused-logit>`_.
- `Interpolate the latent space of any variational model <https://pytorch-lightning-bolts.readthedocs.io/en/latest/variational_callbacks.html#latent-dim-interpolator>`_.
- `Log images to Tensorboard for any model <https://pytorch-lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#tensorboard-image-generator>`_.
- `Add a MLP to fine-tune self-supervised networks <https://lightning-bolts.readthedocs.io/en/latest/self_supervised_callbacks.html#sslonlineevaluator>`_.
- `Find how to modify an image input to trick the classification result <https://lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#confused-logit>`_.
- `Interpolate the latent space of any variational model <https://lightning-bolts.readthedocs.io/en/latest/variational_callbacks.html#latent-dim-interpolator>`_.
- `Log images to Tensorboard for any model <https://lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#tensorboard-image-generator>`_.


--------------
Expand All @@ -85,7 +85,7 @@ Lightning has a few built-in callbacks.

.. note::
For a richer collection of callbacks, check out our
`bolts library <https://pytorch-lightning-bolts.readthedocs.io/en/latest/callbacks.html>`_.
`bolts library <https://lightning-bolts.readthedocs.io/en/latest/callbacks.html>`_.

.. currentmodule:: pytorch_lightning.callbacks

Expand Down
16 changes: 8 additions & 8 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ PyTorch Lightning Documentation

ecosystem/pytorch_ecoystem
ecosystem/community_examples
Autoencoder <https://pytorch-lightning-bolts.readthedocs.io/en/latest/autoencoders.html#autoencoders>
BYOL <https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol>
DQN <https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#deep-q-network-dqn>
GAN <https://pytorch-lightning-bolts.readthedocs.io/en/latest/gans.html#basic-gan>
GPT-2 <https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2>
Image-GPT <https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#image-gpt>
SimCLR <https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr>
VAE <https://pytorch-lightning-bolts.readthedocs.io/en/latest/autoencoders.html#basic-vae>
Autoencoder <https://lightning-bolts.readthedocs.io/en/latest/autoencoders.html#autoencoders>
BYOL <https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol>
DQN <https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#deep-q-network-dqn>
GAN <https://lightning-bolts.readthedocs.io/en/latest/gans.html#basic-gan>
GPT-2 <https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2>
Image-GPT <https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#image-gpt>
SimCLR <https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr>
VAE <https://lightning-bolts.readthedocs.io/en/latest/autoencoders.html#basic-vae>

.. toctree::
:maxdepth: 1
Expand Down
4 changes: 1 addition & 3 deletions docs/source/starter/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,7 @@ Next, install the required xla library (adds support for PyTorch on TPUs)

.. code-block:: shell
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
of this program. This means that without taking any care you will download the dataset N times which
Expand Down
2 changes: 1 addition & 1 deletion notebooks/07-cifar10-baseline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"id": "ziAQCrE-TYWG"
},
"source": [
"! pip install pytorch-lightning pytorch-lightning-bolts -qU"
"! pip install pytorch-lightning lightning-bolts -qU"
],
"execution_count": null,
"outputs": []
Expand Down
4 changes: 2 additions & 2 deletions pl_examples/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Examples
Our most robust examples showing all sorts of implementations
can be found in our sister library [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2).
can be found in our sister library [PyTorch-Lightning-Bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2).

---

Expand All @@ -15,5 +15,5 @@ In this folder we add 3 simple examples:

## Domain examples
This folder contains older examples. You should instead use the examples
in [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
in [PyTorch-Lightning-Bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
for advanced use cases.
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def instantiate_datamodule(args):
if __name__ == "__main__":
cli_lightning_logo()

assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via `pip install lightning-bolts`"
assert _FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."

parser = ArgumentParser(description="Pipe Example")
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
# When `current_epoch` is 10, feature_extractor will start training.
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor,
modules=pl_module.feature_extractor,
optimizer=optimizer,
train_bn=True,
)
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def total_val_batches(self) -> int:
validation dataloader is of infinite size.
"""
total_val_batches = 0
if not self.trainer.disable_validation:
is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0

return total_val_batches

@property
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ def sanitize_parameters_to_prune(
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]

if parameters_to_prune is None:
parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)]
parameters_to_prune = [
(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
]
elif (
isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0
and all(len(p) == 2 for p in parameters_to_prune)
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
_scheduler_config = _get_default_scheduler_config()
assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
_scheduler_config["scheduler"] = self._swa_scheduler

if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
trainer.lr_schedulers[0] = _scheduler_config
else:
_scheduler_config = _get_default_scheduler_config()
_scheduler_config["scheduler"] = self._swa_scheduler
trainer.lr_schedulers.append(_scheduler_config)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.2.6'
__version__ = '1.2.7'
__author__ = 'William Falcon et al.'
__author_email__ = '[email protected]'
__license__ = 'Apache-2.0'
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,12 @@ def pre_dispatch(self):
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

# move the model to the correct device
self.model_to_device()

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

self.configure_ddp()

self.barrier()
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ def new_process(self, process_idx, trainer, mp_queue):
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

# move the model to the correct device
self.model_to_device()

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

self.configure_ddp()

self.barrier()
Expand Down
Loading

0 comments on commit b2c7345

Please sign in to comment.