Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Override add_argparse_args in the FlashTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 26, 2021
1 parent 2113c11 commit 4ad5bcf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
14 changes: 11 additions & 3 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from argparse import ArgumentParser
from functools import wraps
from typing import Callable, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -49,7 +51,7 @@ def insert_env_defaults(self, *args, **kwargs):
return insert_env_defaults


class Trainer(Trainer):
class Trainer(PlTrainer):

@_defaults_from_env_vars
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -172,3 +174,9 @@ def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
override_types = new_callbacks_types.intersection(old_callbacks_types)
new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types)
return new_callbacks

@classmethod
def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
    """Extend an :class:`~argparse.ArgumentParser` with the Lightning ``Trainer``'s init arguments.

    Deliberately passes the base ``PlTrainer`` class — not ``cls`` — to the
    Lightning helper, because the helper inspects the class's ``__init__``
    signature and this subclass's ``(*args, **kwargs)`` signature would yield
    no usable arguments.

    Returns:
        The (possibly newly created) parser populated with the base
        ``pytorch_lightning.Trainer`` constructor arguments.
    """
    # the lightning trainer implementation does not support subclasses.
    # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
    return add_argparse_args(PlTrainer, *args, **kwargs)
8 changes: 8 additions & 0 deletions tests/core/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from typing import Any

import pytest
Expand Down Expand Up @@ -110,3 +111,10 @@ def test_resolve_callbacks_override_warning(tmpdir):
task = FinetuneClassificationTask(model, loss_fn=F.nll_loss)
with pytest.warns(UserWarning, match="The model contains a default finetune callback"):
trainer._resolve_callbacks(task, "test")


def test_add_argparse_args():
    """The Flash ``Trainer`` should expose the base Lightning CLI flags via argparse."""
    argument_parser = Trainer.add_argparse_args(ArgumentParser())
    parsed = argument_parser.parse_args(["--gpus=1"])
    assert parsed.gpus == 1

0 comments on commit 4ad5bcf

Please sign in to comment.