Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix Flash Zero datamodule kwargs #994

Merged
merged 4 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where translation metrics were not computed correctly ([#992](https://github.com/PyTorchLightning/lightning-flash/pull/992))

- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
31 changes: 25 additions & 6 deletions flash/core/utilities/flash_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def wrapper(*args, **kwargs):
return wrapper


def get_kwarg_name(func) -> Optional[str]:
sig = signature(func)
var_kwargs = [p for p in sig.parameters.values() if p.kind == p.VAR_KEYWORD]
if len(var_kwargs) == 1:
return var_kwargs[0].name
return None


def make_args_optional(cls, args: Set[str]):
@wraps(cls)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -211,22 +219,33 @@ def add_arguments_to_parser(self, parser) -> None:

def add_subcommand_from_function(self, subcommands, function, function_name=None):
subcommand = ArgumentParser()
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
if self.legacy:
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls))
subcommand.add_class_arguments(
input_transform_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, input_transform_function),
)
else:
base_datamodule_function = class_from_function(drop_kwargs(self.local_datamodule_class))
elif get_kwarg_name(function) == "data_module_kwargs":
datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
subcommand.add_class_arguments(
base_datamodule_function,
datamodule_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, base_datamodule_function),
skip={
"self",
"train_dataset",
Borda marked this conversation as resolved.
Show resolved Hide resolved
"val_dataset",
"test_dataset",
"predict_dataset",
"input",
"input_transform",
},
)
else:
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
subcommand_name = function_name or function.__name__
subcommands.add_subcommand(subcommand_name, subcommand)
self._subcommand_builders[subcommand_name] = function
Expand Down
8 changes: 6 additions & 2 deletions flash/tabular/forecasting/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def from_synthetic_ar_data(
n_series: int = 100,
max_encoder_length: int = 60,
max_prediction_length: int = 20,
**data_module_kwargs,
batch_size: int = 4,
num_workers: int = 0,
**time_series_dataset_kwargs,
Borda marked this conversation as resolved.
Show resolved Hide resolved
) -> TabularForecastingData:
"""Creates and loads a synthetic auto-regressive (AR) data set."""
data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42)
Expand All @@ -51,7 +53,9 @@ def from_synthetic_ar_data(
max_prediction_length=max_prediction_length,
train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
val_data_frame=data,
**data_module_kwargs,
batch_size=batch_size,
num_workers=num_workers,
**time_series_dataset_kwargs,
)


Expand Down