Make `categorical_input` and `numerical_input` optional for TabularData.

Either list may now be omitted in `__init__`, `from_df` and `from_csv`, but at
least one of them has to name a column: a RuntimeError is raised when both are
None or empty. `target` moves ahead of the two lists in `__init__` because a
required parameter cannot follow parameters that have defaults.

diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py
index fdb566ac20..8d9977af22 100644
--- a/flash/tabular/classification/data/data.py
+++ b/flash/tabular/classification/data/data.py
@@ -71,9 +71,9 @@ class TabularData(DataModule):
     def __init__(
         self,
         train_df: DataFrame,
-        categorical_input: List,
-        numerical_input: List,
         target: str,
+        categorical_input: Optional[List] = None,
+        numerical_input: Optional[List] = None,
         valid_df: Optional[DataFrame] = None,
         test_df: Optional[DataFrame] = None,
         batch_size: int = 2,
@@ -82,6 +82,15 @@ def __init__(
         dfs = [train_df]
         self._test_df = None
 
+        # Refuse to continue when no input column is named at all; an empty
+        # list is treated the same as an omitted (None) one.
+        if not categorical_input and not numerical_input:
+            raise RuntimeError('Both `categorical_input` and `numerical_input` are None or empty!')
+
+        # Replace an omitted (None) list with an empty list.
+        categorical_input = categorical_input if categorical_input is not None else []
+        numerical_input = numerical_input if numerical_input is not None else []
+
         if valid_df is not None:
             dfs.append(valid_df)
 
@@ -133,8 +142,8 @@ def from_df(
         cls,
         train_df: DataFrame,
         target: str,
-        categorical_input: List,
-        numerical_input: List,
+        categorical_input: Optional[List] = None,
+        numerical_input: Optional[List] = None,
         valid_df: Optional[DataFrame] = None,
         test_df: Optional[DataFrame] = None,
         batch_size: int = 8,
@@ -194,8 +203,8 @@ def from_csv(
         cls,
         train_csv: str,
         target: str,
-        categorical_input: List,
-        numerical_input: List,
+        categorical_input: Optional[List] = None,
+        numerical_input: Optional[List] = None,
         valid_csv: Optional[str] = None,
         test_csv: Optional[str] = None,
         batch_size: int = 8,
diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py
index 2f533e93e7..7ddbfeb5ea 100644
--- a/tests/tabular/data/test_data.py
+++ b/tests/tabular/data/test_data.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import pandas as pd
+import pytest
 
 from flash.tabular import TabularData
 from flash.tabular.classification.data.dataset import _categorize, _normalize
@@ -169,3 +170,13 @@ def test_from_csv(tmpdir):
     assert cat.shape == (1, 1)
     assert num.shape == (1, 2)
     assert target.shape == (1, )
+
+
+@pytest.mark.parametrize("empty", [None, []])
+def test_empty_inputs(empty):
+    """Building the data module without any input column must raise RuntimeError."""
+    train_df = TEST_DF_1.copy()
+    with pytest.raises(RuntimeError):
+        TabularData.from_df(
+            train_df, categorical_input=empty, numerical_input=empty, target="label", num_workers=0, batch_size=1
+        )