WIP - do not merge: documenting the clash between the WandB logger and LightningCLI #7675

Closed
7 changes: 5 additions & 2 deletions environment.yml
@@ -25,10 +25,11 @@ channels:
 - pytorch-nightly
 
 dependencies:
-- python>=3.6
+- python=3.8
 - pip>20.1
+- cudatoolkit=10.2
 - numpy>=1.16.4
-- pytorch>=1.4
+- pytorch=1.8.1
 - future>=0.17.1
 - PyYAML>=5.1
 - tqdm>=4.41.0
@@ -54,3 +55,5 @@ dependencies:
 - horovod>=0.21.2
 - onnxruntime>=1.3.0
 - gym>=0.17.0
+- jsonargparse[signatures]
+- pytorch-lightning==1.3.2
24 changes: 22 additions & 2 deletions pl_examples/basic_examples/autoencoder.py
@@ -25,7 +25,7 @@
 
 import pytorch_lightning as pl
 from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
-from pytorch_lightning.utilities.cli import LightningCLI
+from pytorch_lightning.utilities.cli import LightningCLI, SaveConfigCallback
 from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
 
 if _TORCHVISION_AVAILABLE:
@@ -114,8 +114,28 @@ def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size)
 
 
+class WandBandSafeConfigCallBackFixCLI(LightningCLI):
+
+    def add_arguments_to_parser(self, parser):
+        parser.add_argument(
+            'save_config_callback_filename',
+            default='',
+            help='Change the config filename in order to avoid clashes with wandb. '
+                 'TODO: submit a PR for setting this in the LightningCLI constructor/self.config.'
+        )
+
+    def before_fit(self):
+        save_config_cb = [c for c in self.trainer.callbacks if isinstance(c, SaveConfigCallback)]
+        if save_config_cb:
+            config_filename = self.config.get('save_config_callback_filename', '')
+            if config_filename:
+                save_config_cb[0].config_filename = config_filename
+
+
 def cli_main():
-    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
+    # cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
+    cli = WandBandSafeConfigCallBackFixCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
+
     cli.trainer.test(cli.model, datamodule=cli.datamodule)
 
 
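For reference, the `before_fit` logic in the patch can be exercised without Lightning at all. The sketch below is a minimal stand-alone illustration: `DummySaveConfigCallback`, `DummyOtherCallback`, and `apply_filename_override` are hypothetical stand-ins (not Lightning APIs) that mirror the callback filtering and the empty-string guard.

```python
# Stand-alone sketch of the before_fit override; DummySaveConfigCallback
# and DummyOtherCallback are hypothetical stand-ins, not Lightning classes.

class DummySaveConfigCallback:
    config_filename = 'config.yaml'  # same default as Lightning's SaveConfigCallback


class DummyOtherCallback:
    pass


def apply_filename_override(callbacks, config):
    # Mirrors WandBandSafeConfigCallBackFixCLI.before_fit: find the save-config
    # callback and rename its output file only if a non-empty override is set.
    save_config_cb = [c for c in callbacks if isinstance(c, DummySaveConfigCallback)]
    if save_config_cb:
        config_filename = config.get('save_config_callback_filename', '')
        if config_filename:
            save_config_cb[0].config_filename = config_filename


callbacks = [DummyOtherCallback(), DummySaveConfigCallback()]

# Empty override (the autoencoder.yml default): the clashing name is kept.
apply_filename_override(callbacks, {'save_config_callback_filename': ''})
print(callbacks[1].config_filename)  # config.yaml

# Non-empty override: the callback now writes to the new name.
apply_filename_override(callbacks, {'save_config_callback_filename': 'another-name-config.yaml'})
print(callbacks[1].config_filename)  # another-name-config.yaml
```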
14 changes: 14 additions & 0 deletions pl_examples/basic_examples/autoencoder.yml
@@ -0,0 +1,14 @@
+# See SaveConfigCallback: config_filename: str = 'config.yaml'
+# This breaks the example.
+save_config_callback_filename: ''  # Keeps the LightningCLI intact - do not apply the workaround
+#
+# Workaround - uncomment to make the example work, or specify it as a CLI option:
+# save_config_callback_filename: 'another-name-config.yaml'  # See WandBandSafeConfigCallBackFixCLI
+
+
+trainer:
+  max_epochs: 3
+  logger:
+    class_path: pytorch_lightning.loggers.WandbLogger
+    init_args:
+      name: oplatek-pl-documenting-wandb-lightningCLI-clash
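The clash documented by the comments above boils down to two components defaulting to the same file name, `config.yaml`, in one run directory. Below is a stdlib-only simulation of that failure mode and of why renaming one side suffices; `write_config` is a hypothetical helper for illustration, not a Lightning or wandb API.

```python
import pathlib
import tempfile

# Simulates the clash: SaveConfigCallback and wandb both default to a file
# named 'config.yaml', so whichever writes second clobbers the other.
# write_config is a hypothetical helper, not a Lightning or wandb API.

def write_config(directory, filename, body):
    path = pathlib.Path(directory) / filename
    path.write_text(body)
    return path

with tempfile.TemporaryDirectory() as log_dir:
    # Both writers use the default name: the second write silently wins.
    write_config(log_dir, 'config.yaml', 'written by SaveConfigCallback\n')
    write_config(log_dir, 'config.yaml', 'written by wandb\n')
    clobbered = (pathlib.Path(log_dir) / 'config.yaml').read_text().strip()
    print(clobbered)  # written by wandb

    # With the workaround, SaveConfigCallback writes under a different name,
    # so both files survive side by side.
    write_config(log_dir, 'another-name-config.yaml', 'written by SaveConfigCallback\n')
    names = sorted(p.name for p in pathlib.Path(log_dir).iterdir())
    print(names)  # ['another-name-config.yaml', 'config.yaml']
```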
@@ -0,0 +1,12 @@
+#!/bin/bash
+# Based on master branch commit 01109cdf0c44a150c262b65e70a7e1e64003cf93.
+
+ARGS_EXTRA_DDP=" --trainer.gpus 2 --trainer.accelerator ddp"
+ARGS_EXTRA_AMP=" --trainer.precision 16"
+
+# The conda env was created with the following commands in the git root directory:
+# conda env create --file environment.yml --prefix $PWD/env
+# conda activate $PWD/env
+# pip install pytorch-lightning==1.3.2
+conda activate ../env
+python basic_examples/autoencoder.py --config basic_examples/autoencoder.yml ${ARGS_EXTRA_DDP} ${ARGS_EXTRA_AMP} "$@"