From 4ad5bcfc020883a823b57825ac57e84be64c47d4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 26 May 2021 18:05:18 +0200 Subject: [PATCH 1/2] Override `add_argparse_args` in the `FlashTrainer` --- flash/core/trainer.py | 14 +++++++++++--- tests/core/test_trainer.py | 8 ++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 5b5eccfe45..942533a97d 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -12,14 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings +from argparse import ArgumentParser from functools import wraps from typing import Callable, List, Optional, Union import torch -from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning import Trainer as PlTrainer from pytorch_lightning.callbacks import BaseFinetuning from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables +from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader @@ -49,7 +51,7 @@ def insert_env_defaults(self, *args, **kwargs): return insert_env_defaults -class Trainer(Trainer): +class Trainer(PlTrainer): @_defaults_from_env_vars def __init__(self, *args, **kwargs): @@ -172,3 +174,9 @@ def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: override_types = new_callbacks_types.intersection(old_callbacks_types) new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types) return new_callbacks + + @classmethod + def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser: + # the lightning trainer implementation does not support subclasses. + # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447 + return add_argparse_args(PlTrainer, *args, **kwargs) diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index f639386d98..c63750b8ae 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from argparse import ArgumentParser from typing import Any import pytest @@ -110,3 +111,10 @@ def test_resolve_callbacks_override_warning(tmpdir): task = FinetuneClassificationTask(model, loss_fn=F.nll_loss) with pytest.warns(UserWarning, match="The model contains a default finetune callback"): trainer._resolve_callbacks(task, "test") + + +def test_add_argparse_args(): + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args(['--gpus=1']) + assert args.gpus == 1 From 0366f44c04bf2b633d0822abc1ca963bf4fce858 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 26 May 2021 18:11:13 +0200 Subject: [PATCH 2/2] Update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8772672c51..58dfd5cb70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.1] - YYYY-MM-DD +### Fixed + +- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343)) + ## [0.3.0] - 2021-05-20