diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d2dde3140..f8cc8447d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where translation metrics were not computed correctly ([#992](https://github.com/PyTorchLightning/lightning-flash/pull/992)) +- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994)) + - Fixed a bug where the TabularForecaster would not work with some versions of pandas ([#995](https://github.com/PyTorchLightning/lightning-flash/pull/995)) ### Removed diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index bd0992afba..77814dad6b 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -76,6 +76,14 @@ def wrapper(*args, **kwargs): return wrapper +def get_kwarg_name(func) -> Optional[str]: + sig = signature(func) + var_kwargs = [p for p in sig.parameters.values() if p.kind == p.VAR_KEYWORD] + if len(var_kwargs) == 1: + return var_kwargs[0].name + return None + + def make_args_optional(cls, args: Set[str]): @wraps(cls) def wrapper(*args, **kwargs): @@ -211,22 +219,33 @@ def add_arguments_to_parser(self, parser) -> None: def add_subcommand_from_function(self, subcommands, function, function_name=None): subcommand = ArgumentParser() - datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) - subcommand.add_class_arguments(datamodule_function, fail_untyped=False) if self.legacy: + datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) + subcommand.add_class_arguments(datamodule_function, fail_untyped=False) input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls)) subcommand.add_class_arguments( input_transform_function, fail_untyped=False, skip=get_overlapping_args(datamodule_function, input_transform_function), ) - else: - base_datamodule_function = class_from_function(drop_kwargs(self.local_datamodule_class)) + elif get_kwarg_name(function) == "data_module_kwargs": + datamodule_function = class_from_function(function, return_type=self.local_datamodule_class) subcommand.add_class_arguments( - base_datamodule_function, + datamodule_function, fail_untyped=False, - skip=get_overlapping_args(datamodule_function, base_datamodule_function), + skip={ + "self", + "train_dataset", + "val_dataset", + "test_dataset", + "predict_dataset", + "input", + "input_transform", + }, ) + else: + datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) + subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) self._subcommand_builders[subcommand_name] = function diff --git a/flash/tabular/forecasting/cli.py b/flash/tabular/forecasting/cli.py index 44548587b8..f2260a68bf 100644 --- a/flash/tabular/forecasting/cli.py +++ b/flash/tabular/forecasting/cli.py @@ -32,7 +32,9 @@ def from_synthetic_ar_data( n_series: int = 100, max_encoder_length: int = 60, max_prediction_length: int = 20, - **data_module_kwargs, + batch_size: int = 4, + num_workers: int = 0, + **time_series_dataset_kwargs, ) -> TabularForecastingData: """Creates and loads a synthetic auto-regressive (AR) data set.""" data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42) @@ -51,7 +53,9 @@ def from_synthetic_ar_data( max_prediction_length=max_prediction_length, train_data_frame=data[lambda x: x.time_idx <= training_cutoff], val_data_frame=data, - **data_module_kwargs, + batch_size=batch_size, + num_workers=num_workers, + **time_series_dataset_kwargs, )