From ce617a71ee9e8345649cb3aa5b6c5e218563e1f5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Nov 2021 16:59:13 +0000 Subject: [PATCH] Fixes --- flash/core/utilities/flash_cli.py | 25 ++++++++++++++++++++++++- flash/tabular/forecasting/cli.py | 8 ++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index 673c07196c..77814dad6b 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -76,6 +76,14 @@ def wrapper(*args, **kwargs): return wrapper +def get_kwarg_name(func) -> Optional[str]: + sig = signature(func) + var_kwargs = [p for p in sig.parameters.values() if p.kind == p.VAR_KEYWORD] + if len(var_kwargs) == 1: + return var_kwargs[0].name + return None + + def make_args_optional(cls, args: Set[str]): @wraps(cls) def wrapper(*args, **kwargs): @@ -220,8 +228,23 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None fail_untyped=False, skip=get_overlapping_args(datamodule_function, input_transform_function), ) - else: + elif get_kwarg_name(function) == "data_module_kwargs": datamodule_function = class_from_function(function, return_type=self.local_datamodule_class) + subcommand.add_class_arguments( + datamodule_function, + fail_untyped=False, + skip={ + "self", + "train_dataset", + "val_dataset", + "test_dataset", + "predict_dataset", + "input", + "input_transform", + }, + ) + else: + datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) diff --git a/flash/tabular/forecasting/cli.py b/flash/tabular/forecasting/cli.py index 44548587b8..f2260a68bf 100644 --- a/flash/tabular/forecasting/cli.py +++ b/flash/tabular/forecasting/cli.py @@ -32,7 +32,9 @@ def from_synthetic_ar_data( n_series: int = 100, max_encoder_length: int = 60, max_prediction_length: int = 20, - **data_module_kwargs, + batch_size: int = 4, + num_workers: int = 0, + **time_series_dataset_kwargs, ) -> TabularForecastingData: """Creates and loads a synthetic auto-regressive (AR) data set.""" data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42) @@ -51,7 +53,9 @@ def from_synthetic_ar_data( max_prediction_length=max_prediction_length, train_data_frame=data[lambda x: x.time_idx <= training_cutoff], val_data_frame=data, - **data_module_kwargs, + batch_size=batch_size, + num_workers=num_workers, + **time_series_dataset_kwargs, )