Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix from_argparse_args #380

Merged
merged 3 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
- Fixed a bug where using `val_split` with `overfit_batches` would give an infinite recursion ([#375](https://github.com/PyTorchLightning/lightning-flash/pull/375))
- Fixed a bug where some timm models were mistakenly given a `global_pool` argument ([#376](https://github.com/PyTorchLightning/lightning-flash/pull/376))
- Fixed `flash.Trainer.from_argparse_args` not passing arguments correctly ([#380](https://github.com/PyTorchLightning/lightning-flash/pull/380))


## [0.3.0] - 2021-05-20
Expand Down
25 changes: 24 additions & 1 deletion flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from functools import wraps
from typing import Callable, List, Optional, Union

Expand All @@ -29,6 +30,22 @@
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""Modified version of ``pytorch_lightning.utilities.argparse.from_argparse_args`` which populates ``valid_kwargs``
from ``pytorch_lightning.Trainer``."""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)

params = vars(args)

# we only want to pass in valid PLTrainer args, the rest may be user specific
valid_kwargs = inspect.signature(PlTrainer.__init__).parameters
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
trainer_kwargs.update(**kwargs)

return cls(**trainer_kwargs)


def _defaults_from_env_vars(fn: Callable) -> Callable:
"""Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``. Required to fix
build error in readthedocs."""
Expand Down Expand Up @@ -180,3 +197,9 @@ def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
return add_argparse_args(PlTrainer, *args, **kwargs)

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
return from_argparse_args(PlTrainer, args, **kwargs)
8 changes: 8 additions & 0 deletions tests/core/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,11 @@ def test_add_argparse_args():
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args(['--gpus=1'])
assert args.gpus == 1


def test_from_argparse_args():
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args(['--max_epochs=200'])
trainer = Trainer.from_argparse_args(args)
assert trainer.max_epochs == 200