From 1be09d3c63346caeb090ddc212924f8b9cace645 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 27 May 2021 19:31:32 +0200 Subject: [PATCH] Added save_config_filename init argument to LightningCLI --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/cli.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4259684748e85..f5a5330e7bfb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684)) +- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741)) + + ### Changed - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 5dccad4ab9135..5142864a57af3 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -76,7 +76,7 @@ def __init__( self, parser: LightningArgumentParser, config: Union[Namespace, Dict[str, Any]], - config_filename: str = 'config.yaml' + config_filename: str, ) -> None: self.parser = parser self.config = config @@ -96,6 +96,7 @@ def __init__( model_class: Type[LightningModule], datamodule_class: Type[LightningDataModule] = None, save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, + save_config_filename: str = 'config.yaml', trainer_class: Type[Trainer] = Trainer, trainer_defaults: Dict[str, Any] = None, seed_everything_default: int = None, @@ -154,6 +155,7 @@ def __init__( self.model_class = model_class self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback + self.save_config_filename = save_config_filename self.trainer_class = trainer_class self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults self.seed_everything_default = seed_everything_default @@ -241,7 +243,8 @@ def instantiate_trainer(self) -> None: else: self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) if self.save_config_callback is not None: - self.config_init['trainer']['callbacks'].append(self.save_config_callback(self.parser, self.config)) + config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename) + self.config_init['trainer']['callbacks'].append(config_callback) self.trainer = self.trainer_class(**self.config_init['trainer']) def prepare_fit_kwargs(self) -> None: