Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add sampler argument to tabular data (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 23, 2021
1 parent ca3870a commit 81b73ad
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))


## [0.5.0] - 2021-09-07

Expand Down
10 changes: 9 additions & 1 deletion flash/tabular/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data.sampler import Sampler

from flash.core.classification import LabelsState
from flash.core.data.callback import BaseDataFetcher
Expand Down Expand Up @@ -344,6 +345,7 @@ def from_data_frame(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames.
Expand Down Expand Up @@ -372,6 +374,7 @@ def from_data_frame(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -420,6 +423,7 @@ def from_data_frame(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
cat_cols=categorical_fields,
num_cols=numerical_fields,
target_col=target_fields,
Expand Down Expand Up @@ -451,6 +455,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files.
Expand Down Expand Up @@ -479,6 +484,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -506,4 +512,6 @@ def from_csv(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

0 comments on commit 81b73ad

Please sign in to comment.