diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 5a02c0cd57..fdb566ac20 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -70,13 +70,13 @@ class TabularData(DataModule): def __init__( self, - train_df, + train_df: DataFrame, categorical_input: List, numerical_input: List, target: str, - valid_df=None, - test_df=None, - batch_size=2, + valid_df: Optional[DataFrame] = None, + test_df: Optional[DataFrame] = None, + batch_size: int = 2, num_workers: Optional[int] = None, ): dfs = [train_df] @@ -131,12 +131,12 @@ def num_features(self) -> int: @classmethod def from_df( cls, - train_df: pd.DataFrame, + train_df: DataFrame, target: str, categorical_input: List, numerical_input: List, - valid_df: pd.DataFrame = None, - test_df: pd.DataFrame = None, + valid_df: Optional[DataFrame] = None, + test_df: Optional[DataFrame] = None, batch_size: int = 8, num_workers: Optional[int] = None, val_size: float = None, @@ -192,7 +192,7 @@ def from_df( @classmethod def from_csv( cls, - train_csv, + train_csv: str, target: str, categorical_input: List, numerical_input: List, diff --git a/flash/tabular/classification/data/dataset.py b/flash/tabular/classification/data/dataset.py index 5fa7298ba7..da653f3549 100644 --- a/flash/tabular/classification/data/dataset.py +++ b/flash/tabular/classification/data/dataset.py @@ -129,7 +129,15 @@ def _dfs_to_samples(dfs, cat_cols, num_cols) -> list: class PandasDataset(Dataset): - def __init__(self, df, cat_cols, num_cols, target_col, regression=False, predict=False): + def __init__( + self, + df: DataFrame, + cat_cols: List, + num_cols: List, + target_col: str, + regression: bool = False, + predict: bool = False + ): self._num_samples = len(df) self.predict = predict cat_vars = _to_cat_vars_numpy(df, cat_cols)