Merged
Changes from 4 commits
8 changes: 0 additions & 8 deletions autoPyTorch/api/tabular_classification.py
@@ -14,7 +14,6 @@
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     HoldoutValTypes,
-    CrossValTypes,
     ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
@@ -384,13 +383,6 @@ def search(
             dataset_name=dataset_name
         )
 
-        if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-            raise ValueError(
-                'Hyperparameter optimization requires a validation split. '
-                'Expected `self.resampling_strategy` to be either '
-                '(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
-            )
-
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,
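The same guard is deleted from tabular_regression.py below. As a reference, here is a minimal, self-contained sketch (with stand-in enums, not the repo's real ones) of the mechanism the deleted check relied on: enum members are instances of their enum class, so the isinstance() call accepted only holdout/CV strategies and rejected everything else before _search was reached. Removing it presumably lets a no-resampling strategy flow through, which matches the cocktails workflow of fitting on the full training set.

    from enum import IntEnum

    class HoldoutValTypes(IntEnum):            # stand-in for the real enum
        holdout_validation = 6

    class NoResamplingStrategyTypes(IntEnum):  # hypothetical no-split strategy
        no_resampling = 8

    strategy = NoResamplingStrategyTypes.no_resampling
    # Enum members are instances of their enum class, so the deleted guard
    # accepted only CrossValTypes/HoldoutValTypes members:
    print(isinstance(HoldoutValTypes.holdout_validation, HoldoutValTypes))  # True
    print(isinstance(strategy, HoldoutValTypes))                            # False -> used to raise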
8 changes: 0 additions & 8 deletions autoPyTorch/api/tabular_regression.py
@@ -14,7 +14,6 @@
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     HoldoutValTypes,
-    CrossValTypes,
     ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
@@ -384,13 +383,6 @@ def search(
             dataset_name=dataset_name
         )
 
-        if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-            raise ValueError(
-                'Hyperparameter optimization requires a validation split. '
-                'Expected `self.resampling_strategy` to be either '
-                '(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
-            )
-
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,
128 changes: 71 additions & 57 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -14,14 +14,13 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.impute import SimpleImputer
 from sklearn.pipeline import make_pipeline
-from sklearn.preprocessing import OneHotEncoder, StandardScaler
+from sklearn.preprocessing import OrdinalEncoder
 
 from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES
 
 
 def _create_column_transformer(
     preprocessors: Dict[str, List[BaseEstimator]],
-    numerical_columns: List[str],
     categorical_columns: List[str],
 ) -> ColumnTransformer:
     """
@@ -32,49 +31,36 @@ def _create_column_transformer(
     Args:
         preprocessors (Dict[str, List[BaseEstimator]]):
             Dictionary containing list of numerical and categorical preprocessors.
-        numerical_columns (List[str]):
-            List of names of numerical columns
         categorical_columns (List[str]):
             List of names of categorical columns
 
     Returns:
         ColumnTransformer
     """
 
-    numerical_pipeline = 'drop'
-    categorical_pipeline = 'drop'
-    if len(numerical_columns) > 0:
-        numerical_pipeline = make_pipeline(*preprocessors['numerical'])
-    if len(categorical_columns) > 0:
-        categorical_pipeline = make_pipeline(*preprocessors['categorical'])
+    categorical_pipeline = make_pipeline(*preprocessors['categorical'])
 
     return ColumnTransformer([
-        ('categorical_pipeline', categorical_pipeline, categorical_columns),
-        ('numerical_pipeline', numerical_pipeline, numerical_columns)],
-        remainder='drop'
+        ('categorical_pipeline', categorical_pipeline, categorical_columns)],
+        remainder='passthrough'
    )
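To see what switching remainder='drop' to remainder='passthrough' buys here, a small sketch (toy column names, not from this PR): the transformed categorical columns come out first, and the untouched numerical columns are appended after them, which is why _fit below has to re-sort self.feat_type.

    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import OrdinalEncoder

    X = pd.DataFrame({'num': [0.5, 1.5],
                      'cat': pd.Series(['a', 'b'], dtype='category')})
    ct = ColumnTransformer(
        [('categorical_pipeline', OrdinalEncoder(), ['cat'])],
        remainder='passthrough',  # numerical columns are no longer dropped or scaled
    )
    # Encoded 'cat' comes first, passthrough 'num' second:
    print(ct.fit_transform(X))  # [[0.  0.5]
                                #  [1.  1.5]]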


 def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
     """
     This function creates a Dictionary containing a list
     of numerical and categorical preprocessors
 
     Returns:
         Dict[str, List[BaseEstimator]]
     """
     preprocessors: Dict[str, List[BaseEstimator]] = dict()
 
     # Categorical Preprocessors
-    onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
+    ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
+                                     unknown_value=-1)
     categorical_imputer = SimpleImputer(strategy='constant', copy=False)
 
-    # Numerical Preprocessors
-    numerical_imputer = SimpleImputer(strategy='median', copy=False)
-    standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)
-
-    preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
-    preprocessors['numerical'] = [numerical_imputer, standard_scaler]
+    preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]
 
     return preprocessors
@@ -161,31 +147,47 @@ def _fit(

         X = cast(pd.DataFrame, X)
 
-        self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
+        all_nan_columns = X.columns[X.isna().all()]
+        for col in all_nan_columns:
+            X[col] = pd.to_numeric(X[col])
 
         # Handle objects if possible
         exist_object_columns = has_object_columns(X.dtypes.values)
         if exist_object_columns:
             X = self.infer_objects(X)
 
-        categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)
+        self.dtypes = [dt.name for dt in X.dtypes]  # Also note this change in self.dtypes
+        self.all_nan_columns = set(all_nan_columns)
 
-        self.enc_columns = categorical_columns
+        self.enc_columns, self.feat_type = self._get_columns_info(X)
 
-        preprocessors = get_tabular_preprocessors()
-        self.column_transformer = _create_column_transformer(
-            preprocessors=preprocessors,
-            numerical_columns=numerical_columns,
-            categorical_columns=categorical_columns,
-        )
+        if len(self.enc_columns) > 0:
 
-        # Mypy redefinition
-        assert self.column_transformer is not None
-        self.column_transformer.fit(X)
+            preprocessors = get_tabular_preprocessors()
+            self.column_transformer = _create_column_transformer(
+                preprocessors=preprocessors,
+                categorical_columns=self.enc_columns,
+            )
 
-        # The column transformer reorders the feature types
-        # therefore, we need to change the order of columns as well
-        # This means categorical columns are shifted to the left
+            # Mypy redefinition
+            assert self.column_transformer is not None
+            self.column_transformer.fit(X)
 
-        self.feat_type = sorted(
-            feat_type,
-            key=functools.cmp_to_key(self._comparator)
-        )
+            # The column transformer moves categorical columns before all numerical columns
+            # therefore, we need to sort categorical columns so that it complies this change
+            self.feat_type = sorted(
+                self.feat_type,
+                key=functools.cmp_to_key(self._comparator)
+            )
 
+            encoded_categories = self.column_transformer.\
+                named_transformers_['categorical_pipeline'].\
+                named_steps['ordinalencoder'].categories_
+            self.categories = [
+                list(range(len(cat)))
+                for cat in encoded_categories
+            ]
 
         # differently to categorical_columns and numerical_columns,
         # this saves the index of the column.
Collaborator: The lines below would look better like this (len(enc_columns) > 0 holds for data containing categoricals, right?):

            num_numericals, num_categoricals = self.feat_type.count('numerical'), self.feat_type.count('categorical')
            if num_numericals + num_categoricals != len(self.feat_type):
                raise ValueError("Elements of feat_type must be either ['numerical', 'categorical']")

            self.categorical_columns = list(range(num_categoricals))
            self.numerical_columns = list(range(num_categoricals, num_categoricals + num_numericals))
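Relatedly, a toy sketch (synthetic data, names mirroring the diff above) of how the fitted encoder's categories_ attribute turns into the integer self.categories list built in _fit:

    import pandas as pd
    from sklearn.preprocessing import OrdinalEncoder

    X = pd.DataFrame({'color': ['red', 'green', 'red'], 'size': ['S', 'M', 'L']})
    encoder = OrdinalEncoder().fit(X)
    # categories_ lists the distinct raw values per column, in encoding order
    print(encoder.categories_)
    # [array(['green', 'red'], dtype=object), array(['L', 'M', 'S'], dtype=object)]
    categories = [list(range(len(cat))) for cat in encoder.categories_]
    print(categories)  # [[0, 1], [0, 1, 2]] -- the codes, not the labels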

@@ -265,6 +267,22 @@ def transform(
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
X = cast(Type[pd.DataFrame], X)

if self.all_nan_columns is None:
raise ValueError('_fit must be called before calling transform')

for col in list(self.all_nan_columns):
X[col] = np.nan
X[col] = pd.to_numeric(X[col])

if len(self.categorical_columns) > 0:
if self.column_transformer is None:
raise AttributeError("Expect column transformer to be built"
"if there are categorical columns")
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')

# Check the data here so we catch problems on new test data
self._check_data(X)

@@ -273,11 +291,6 @@ def transform(
         # We need to convert the column in test data to
         # object otherwise the test column is interpreted as float
         if self.column_transformer is not None:
-            if len(self.categorical_columns) > 0:
-                categorical_columns = self.column_transformer.transformers_[0][-1]
-                for column in categorical_columns:
-                    if X[column].isna().all():
-                        X[column] = X[column].astype('object')
             X = self.column_transformer.transform(X)
 
         # Sparse related transformations
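The block moved above guards a genuine pandas pitfall, sketched here on a toy frame: a test column that is all NaN is inferred as float64, so without the astype('object') cast the fitted categorical pipeline would see a numeric dtype where it expects a categorical one.

    import numpy as np
    import pandas as pd

    X_test = pd.DataFrame({'cat': [np.nan, np.nan]})
    print(X_test['cat'].dtype)  # float64 -- pandas infers float for all-NaN columns
    if X_test['cat'].isna().all():
        X_test['cat'] = X_test['cat'].astype('object')
    print(X_test['cat'].dtype)  # object -- consistent with a categorical column again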
@@ -361,7 +374,6 @@ def _check_data(
             self.column_order = column_order
 
         dtypes = [dtype.name for dtype in X.dtypes]
-
         diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
         if len(self.dtypes) == 0:
             self.dtypes = dtypes
@@ -373,7 +385,7 @@ def _check_data(
     def _get_columns_info(
         self,
         X: pd.DataFrame,
-    ) -> Tuple[List[str], List[str], List[str]]:
+    ) -> Tuple[List[str], List[str]]:
"""
Return the columns to be encoded from a pandas dataframe

@@ -392,15 +404,12 @@ def _get_columns_info(
"""

# Register if a column needs encoding
numerical_columns = []
categorical_columns = []
# Also, register the feature types for the estimator
feat_type = []

# Make sure each column is a valid type
for i, column in enumerate(X.columns):
if self.all_nan_columns is not None and column in self.all_nan_columns:
continue
Collaborator (on lines -402 to -403): Why don't we need this anymore?

             column_dtype = self.dtypes[i]
             err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
                       "but input column {} has an invalid type `{}`.".format(column, column_dtype)
@@ -411,7 +420,6 @@ def _get_columns_info(
             # TypeError: data type not understood in certain pandas types
             elif is_numeric_dtype(column_dtype):
                 feat_type.append('numerical')
-                numerical_columns.append(column)
             elif column_dtype == 'object':
                 # TODO verify how would this happen when we always convert the object dtypes to category
                 raise TypeError(
@@ -437,7 +445,7 @@ def _get_columns_info(
"before feeding it to AutoPyTorch.".format(err_msg)
)

return categorical_columns, numerical_columns, feat_type
return categorical_columns, feat_type

     def list_to_pandas(
         self,
@@ -517,11 +525,17 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
                     self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
                     pass
         else:
-            # Calling for the first time to infer the categories
-            X = X.infer_objects()
-            for column, data_type in zip(X.columns, X.dtypes):
-                if not is_numeric_dtype(data_type):
-                    X[column] = X[column].astype('category')
+            if len(self.dtypes) != 0:
+                # when train data has no object dtype, but test does
+                # we prioritise the datatype given in training data
+                for column, data_type in zip(X.columns, self.dtypes):
+                    X[column] = X[column].astype(data_type)
+            else:
+                # Calling for the first time to infer the categories
+                X = X.infer_objects()
+                for column, data_type in zip(X.columns, X.dtypes):
+                    if not is_numeric_dtype(data_type):
+                        X[column] = X[column].astype('category')
Collaborator suggested change:
-            if len(self.dtypes) != 0:
-                # when train data has no object dtype, but test does
-                # we prioritise the datatype given in training data
-                for column, data_type in zip(X.columns, self.dtypes):
-                    X[column] = X[column].astype(data_type)
-            else:
-                # Calling for the first time to infer the categories
-                X = X.infer_objects()
-                for column, data_type in zip(X.columns, X.dtypes):
-                    if not is_numeric_dtype(data_type):
-                        X[column] = X[column].astype('category')
+        elif len(self.dtypes) != 0:  # when train data has no object dtype, but test does
+            # we prioritise the datatype given in training data
+            for column, data_type in zip(X.columns, self.dtypes):
+                X[column] = X[column].astype(data_type)
+        else:  # Calling for the first time to infer the categories
+            X = X.infer_objects()
+            for column, data_type in zip(X.columns, X.dtypes):
+                if not is_numeric_dtype(data_type):
+                    X[column] = X[column].astype('category')

Contributor Author: I think these are just preferences on where to start the comment.
Collaborator: No, no, it actually removed an indent level.

Contributor Author: Actually, if you notice, we are also saving the dtypes in self.object_dtype_mapping, which should be done for both of the two conditions you moved back an indent level. So I think it's fine the way it is.

Collaborator (@nabenabe0928, Feb 3, 2022): Oh yeah, I did not notice that. I also had not noticed that we still have the same issue (which happens when we have a huge number of features) in this method. Could you fix it?

        if hasattr(self, 'object_dtype_mapping'):
            # Mypy does not process the has attr. This dict is defined below
            try:
                X = X.astype(self.object_dtype_mapping)
            except Exception as e:
                self.logger.warning(f'Casting test data to data type in train data caused the exception {e}')
                pass
            return X

        if len(self.dtypes) != 0:
            # when train data has no object dtype, but test does.  Prioritise the datatype given in training data
            dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)}
            X = X.astype(dtype_dict)
        else:
            # Calling for the first time to infer the categories
            X = X.infer_objects()
            dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)}
            X = X.astype(dtype_dict)

        # only numerical attributes and categories
        self.object_dtype_mapping = {col: dtype for col, dtype in zip(X.columns, X.dtypes)}
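A rough sketch of the performance point behind this suggestion (synthetic shapes, timings indicative only): assigning columns one by one can trigger repeated internal copying in pandas, while a single astype call with a dtype dict casts everything in one pass, which is what the reviewer's version does.

    import time
    import numpy as np
    import pandas as pd

    X = pd.DataFrame(np.random.rand(100, 5000))  # wide frame, as in the reviewer's concern

    start = time.perf_counter()
    for col in X.columns:                        # per-column assignment
        X[col] = X[col].astype('float32')
    print('column loop  :', time.perf_counter() - start)

    X2 = pd.DataFrame(np.random.rand(100, 5000))
    start = time.perf_counter()
    X2 = X2.astype({col: 'float32' for col in X2.columns})  # one batched cast
    print('single astype:', time.perf_counter() - start)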


         # only numerical attributes and categories
         self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}
3 changes: 1 addition & 2 deletions autoPyTorch/datasets/base_dataset.py
@@ -125,7 +125,6 @@ def __init__(
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
-        self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -145,7 +144,7 @@ def __init__(

         # TODO: Look for a criteria to define small enough to preprocess
         # False for the regularization cocktails initially
-        self.is_small_preprocess = False
+        # self.is_small_preprocess = False
 
         # Make sure cross validation splits are created once
         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
7 changes: 0 additions & 7 deletions autoPyTorch/datasets/resampling_strategy.py
@@ -39,13 +39,6 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
         ...
 
 
-class NoResamplingFunc(Protocol):
-    def __call__(self,
-                 random_state: np.random.RandomState,
-                 indices: np.ndarray) -> np.ndarray:
-        ...
-
-
 class CrossValTypes(IntEnum):
     """The type of cross validation
 
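For readers unfamiliar with the pattern deleted above, a minimal sketch (toy strategy function, not from the repo) of how a typing.Protocol with __call__ types these resampling callables structurally: any plain function with a matching signature satisfies the protocol, no inheritance required.

    import numpy as np
    from typing import Protocol

    class NoResamplingFunc(Protocol):
        def __call__(self,
                     random_state: np.random.RandomState,
                     indices: np.ndarray) -> np.ndarray:
            ...

    def no_resampling(random_state: np.random.RandomState,
                      indices: np.ndarray) -> np.ndarray:
        # Trivial strategy: keep every index for training, no validation split
        return indices

    validator: NoResamplingFunc = no_resampling  # a type checker accepts this structurally
    print(validator(np.random.RandomState(1), np.arange(5)))  # [0 1 2 3 4]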