diff --git a/CHANGELOG.md b/CHANGELOG.md index 66d0ab8c1a..11d4b0accf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792)) + ## [0.5.0] - 2021-09-07 diff --git a/flash/tabular/data.py b/flash/tabular/data.py index 6a6f87fa5a..b078344366 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data.sampler import Sampler from flash.core.classification import LabelsState from flash.core.data.callback import BaseDataFetcher @@ -344,6 +345,7 @@ def from_data_frame( val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. @@ -372,6 +374,7 @@ def from_data_frame( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -420,6 +423,7 @@ def from_data_frame( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, cat_cols=categorical_fields, num_cols=numerical_fields, target_col=target_fields, @@ -451,6 +455,7 @@ def from_csv( val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. @@ -479,6 +484,7 @@ def from_csv( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -506,4 +512,6 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, )