Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add sampler argument to tabular data
Browse files (browse the repository at this point in the history)
  • Loading branch information
ethanwharris committed Sep 22, 2021
1 parent f943441 commit 0b014a9
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion flash/tabular/data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data.sampler import Sampler

from flash.core.classification import LabelsState
from flash.core.data.callback import BaseDataFetcher
Expand Down Expand Up @@ -344,6 +345,7 @@ def from_data_frame(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames.
Expand Down Expand Up @@ -372,6 +374,7 @@ def from_data_frame(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -420,6 +423,7 @@ def from_data_frame(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
cat_cols=categorical_fields,
num_cols=numerical_fields,
target_col=target_fields,
Expand Down Expand Up @@ -451,6 +455,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files.
Expand Down Expand Up @@ -479,6 +484,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -506,4 +512,6 @@ def from_csv(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

0 comments on commit 0b014a9

Please sign in to comment.