From 0b014a97017410479ed52bc87ea2d88c5507f180 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 22 Sep 2021 21:49:20 +0100 Subject: [PATCH] Add sampler argument to tabular data --- flash/tabular/data.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flash/tabular/data.py b/flash/tabular/data.py index 6a6f87fa5a..b078344366 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data.sampler import Sampler from flash.core.classification import LabelsState from flash.core.data.callback import BaseDataFetcher @@ -344,6 +345,7 @@ def from_data_frame( val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. @@ -372,6 +374,7 @@ def from_data_frame( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -420,6 +423,7 @@ def from_data_frame( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, cat_cols=categorical_fields, num_cols=numerical_fields, target_col=target_fields, @@ -451,6 +455,7 @@ def from_csv( val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. @@ -479,6 +484,7 @@ def from_csv( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -506,4 +512,6 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, )