Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add sampler argument to tabular data #792

Merged
merged 3 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))


## [0.5.0] - 2021-09-07

Expand Down
10 changes: 9 additions & 1 deletion flash/tabular/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data.sampler import Sampler

from flash.core.classification import LabelsState
from flash.core.data.callback import BaseDataFetcher
Expand Down Expand Up @@ -344,6 +345,7 @@ def from_data_frame(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames.
Expand Down Expand Up @@ -372,6 +374,7 @@ def from_data_frame(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -420,6 +423,7 @@ def from_data_frame(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
cat_cols=categorical_fields,
num_cols=numerical_fields,
target_col=target_fields,
Expand Down Expand Up @@ -451,6 +455,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files.
Expand Down Expand Up @@ -479,6 +484,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -506,4 +512,6 @@ def from_csv(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)