From 4e22c93e576c7cc5b1418d2a2e94bdcd85adc279 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 05:20:18 +0900 Subject: [PATCH 01/10] [refactor] Rename transformed_columns --> enc_columns --- autoPyTorch/data/base_feature_validator.py | 6 ++--- autoPyTorch/data/tabular_feature_validator.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index 11c6cf577..a6cfbb755 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -27,8 +27,8 @@ class BaseFeatureValidator(BaseEstimator): column_transformer (Optional[BaseEstimator]) Host a encoder object if the data requires transformation (for example, if provided a categorical column in a pandas DataFrame) - transformed_columns (List[str]) - List of columns that were encoded. + enc_columns (Optional[List[str]]): + The list of column names that should be encoded. """ def __init__( self, @@ -41,7 +41,7 @@ def __init__( self.column_order: List[str] = [] self.column_transformer: Optional[BaseEstimator] = None - self.transformed_columns: List[str] = [] + self.enc_columns: List[str] = [] self.logger: Union[ PicklableClientLogger, logging.Logger diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 3e8c316b0..256261edb 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -86,7 +86,7 @@ class TabularFeatureValidator(BaseFeatureValidator): List for which an element at each index is a list containing the categories for the respective categorical column. - transformed_columns (List[str]) + enc_columns (List[str]) List of columns that were transformed. 
column_transformer (Optional[BaseEstimator]) Hosts an imputer and an encoder object if the data @@ -175,16 +175,16 @@ def _fit( if not X.select_dtypes(include='object').empty: X = self.infer_objects(X) - self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) + self.enc_columns, self.feat_type = self._get_columns_to_encode(X) assert self.feat_type is not None - if len(self.transformed_columns) > 0: + if len(self.enc_columns) > 0: preprocessors = get_tabular_preprocessors() self.column_transformer = _create_column_transformer( preprocessors=preprocessors, - categorical_columns=self.transformed_columns, + categorical_columns=self.enc_columns, ) # Mypy redefinition @@ -374,7 +374,7 @@ def _check_data( # Define the column to be encoded here as the feature validator is fitted once # per estimator - self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) + self.enc_columns, self.feat_type = self._get_columns_to_encode(X) column_order = [column for column in X.columns] if len(self.column_order) > 0: @@ -412,17 +412,17 @@ def _get_columns_to_encode( checks) and an encoder fitted in the case the data needs encoding Returns: - transformed_columns (List[str]): + enc_columns (List[str]): Columns to encode, if any feat_type: Type of each column numerical/categorical """ - if len(self.transformed_columns) > 0 and self.feat_type is not None: - return self.transformed_columns, self.feat_type + if len(self.enc_columns) > 0 and self.feat_type is not None: + return self.enc_columns, self.feat_type # Register if a column needs encoding - transformed_columns = [] + enc_columns = [] # Also, register the feature types for the estimator feat_type = [] @@ -431,7 +431,7 @@ def _get_columns_to_encode( for i, column in enumerate(X.columns): if X[column].dtype.name in ['category', 'bool']: - transformed_columns.append(column) + enc_columns.append(column) feat_type.append('categorical') # Move away from np.issubdtype as it causes # TypeError: data type not 
understood in certain pandas types @@ -473,7 +473,7 @@ def _get_columns_to_encode( ) else: feat_type.append('numerical') - return transformed_columns, feat_type + return enc_columns, feat_type def list_to_dataframe( self, From c286da730cd759b22e0bc24baac6e3923639ee6a Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 05:27:56 +0900 Subject: [PATCH 02/10] [refactor] Rename list_to_dataframe --> list_to_pandas --- autoPyTorch/data/base_feature_validator.py | 3 +- autoPyTorch/data/tabular_feature_validator.py | 40 ++++++------------- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index a6cfbb755..2b5183550 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -75,7 +75,8 @@ def fit( # If a list was provided, it will be converted to pandas if isinstance(X_train, list): - X_train, X_test = self.list_to_dataframe(X_train, X_test) + X_train = self.list_to_pandas(X_train) + X_test = self.list_to_pandas(X_test) if X_test is not None else None self._check_data(X_train) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 256261edb..2c88cd36c 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -241,7 +241,7 @@ def transform( # If a list was provided, it will be converted to pandas if isinstance(X, list): - X, _ = self.list_to_dataframe(X) + X = self.list_to_pandas(X) if isinstance(X, np.ndarray): X = self.numpy_array_to_pandas(X) @@ -475,42 +475,28 @@ def _get_columns_to_encode( feat_type.append('numerical') return enc_columns, feat_type - def list_to_dataframe( - self, - X_train: SupportedFeatTypes, - X_test: Optional[SupportedFeatTypes] = None, - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: """ - Converts a list 
to a pandas DataFrame. In this process, column types are inferred. - - If test data is provided, we proactively match it to train data + Convert a list to a pandas DataFrame. In this process, column types are inferred. Args: - X_train (SupportedFeatTypes): + X (SupportedFeatTypes): A set of features that are going to be validated (type and dimensionality - checks) and a encoder fitted in the case the data needs encoding - X_test (Optional[SupportedFeatTypes]): - A hold out set of data used for checking + checks) and an encoder fitted in the case the data needs encoding Returns: pd.DataFrame: - transformed train data from list to pandas DataFrame - pd.DataFrame: - transformed test data from list to pandas DataFrame + transformed data from list to pandas DataFrame """ # If a list was provided, it will be converted to pandas - X_train = pd.DataFrame(data=X_train).infer_objects() - self.logger.warning("The provided feature types to AutoPyTorch are of type list." - "Features have been interpreted as: {}".format([(col, t) for col, t in - zip(X_train.columns, X_train.dtypes)])) - if X_test is not None: - if not isinstance(X_test, list): - self.logger.warning("Train features are a list while the provided test data" - "is {}. X_test will be casted as DataFrame.".format(type(X_test)) - ) - X_test = pd.DataFrame(data=X_test).infer_objects() - return X_train, X_test + X = pd.DataFrame(data=X).infer_objects() + data_info = [(col, t) for col, t in zip(X.columns, X.dtypes)] + self.logger.warning( + "The provided feature types to AutoPyTorch are list." 
+ f"Features have been interpreted as: {data_info}" + ) + return X def numpy_array_to_pandas( self, From 2266d415cef70144ae1db07e74db66f683dd12e8 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 05:31:54 +0900 Subject: [PATCH 03/10] [refactor] Rename numpy_array_to_pandas --> numpy_to_pandas --- autoPyTorch/data/tabular_feature_validator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 2c88cd36c..53727efc5 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -154,7 +154,7 @@ def _fit( # The final output of a validator is a numpy array. But pandas # gives us information about the column dtype if isinstance(X, np.ndarray): - X = self.numpy_array_to_pandas(X) + X = self.numpy_to_pandas(X) if ispandas(X) and not issparse(X): X = cast(pd.DataFrame, X) @@ -244,7 +244,7 @@ def transform( X = self.list_to_pandas(X) if isinstance(X, np.ndarray): - X = self.numpy_array_to_pandas(X) + X = self.numpy_to_pandas(X) if ispandas(X) and not issparse(X): if np.any(pd.isnull(X)): @@ -498,7 +498,7 @@ def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: ) return X - def numpy_array_to_pandas( + def numpy_to_pandas( self, X: np.ndarray, ) -> pd.DataFrame: From 080fe95491b7769d7b2595f381732801a3a414d9 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 05:56:20 +0900 Subject: [PATCH 04/10] [refactor] Separate errors due to invalid types in pandas --- autoPyTorch/data/tabular_feature_validator.py | 75 +++++++------------ test/test_data/test_feature_validator.py | 9 +-- 2 files changed, 33 insertions(+), 51 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 53727efc5..1cdce2963 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ 
-74,6 +74,28 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]: return preprocessors +def _error_due_to_unsupported_column(X: pd.DataFrame, column: str) -> None: + # Move away from np.issubdtype as it causes + # TypeError: data type not understood in certain pandas types + def _generate_error_message_prefix(type_name: str, proc_type: Optional[str] = None) -> str: + msg1 = f"column `{column}` has an invalid type `{type_name}`. " + msg2 = "Cast it to a numerical type, category type or bool type by astype method. " + msg3 = f"The following link might help you to know {proc_type} processing: " + return msg1 + msg2 + ("" if proc_type is None else msg3) + + dtype = X[column].dtype + if dtype.name == 'object': + err_msg = _generate_error_message_prefix(type_name="object", proc_type="string") + url = "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html" + raise TypeError(f"{err_msg}{url}") + elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(dtype): + err_msg = _generate_error_message_prefix(type_name="time and/or date datatype", proc_type="datetime") + raise TypeError(f"{err_msg}https://stats.stackexchange.com/questions/311494/") + else: + err_msg = _generate_error_message_prefix(type_name=dtype.name) + raise TypeError(err_msg) + + class TabularFeatureValidator(BaseFeatureValidator): """ A subclass of `BaseFeatureValidator` made for tabular data. 
@@ -399,10 +421,7 @@ def _check_data( else: self.dtypes = dtypes - def _get_columns_to_encode( - self, - X: pd.DataFrame, - ) -> Tuple[List[str], List[str]]: + def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], List[str]]: """ Return the columns to be encoded from a pandas dataframe @@ -428,51 +447,15 @@ def _get_columns_to_encode( feat_type = [] # Make sure each column is a valid type - for i, column in enumerate(X.columns): - if X[column].dtype.name in ['category', 'bool']: - + for dtype, column in zip(X.dtypes, X.columns): + if dtype.name in ['category', 'bool']: enc_columns.append(column) feat_type.append('categorical') - # Move away from np.issubdtype as it causes - # TypeError: data type not understood in certain pandas types - elif not is_numeric_dtype(X[column]): - if X[column].dtype.name == 'object': - raise ValueError( - "Input Column {} has invalid type object. " - "Cast it to a valid dtype before using it in AutoPyTorch. " - "Valid types are numerical, categorical or boolean. " - "You can cast it to a valid dtype using " - "pandas.Series.astype ." - "If working with string objects, the following " - "tutorial illustrates how to work with text data: " - "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( - # noqa: E501 - column, - ) - ) - elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype( - X[column].dtype - ): - raise ValueError( - "AutoPyTorch does not support time and/or date datatype as given " - "in column {}. Please convert the time information to a numerical value " - "first. One example on how to do this can be found on " - "https://stats.stackexchange.com/questions/311494/".format( - column, - ) - ) - else: - raise ValueError( - "Input Column {} has unsupported dtype {}. " - "Supported column types are categorical/bool/numerical dtypes. 
" - "Make sure your data is formatted in a correct way, " - "before feeding it to AutoPyTorch.".format( - column, - X[column].dtype.name, - ) - ) - else: + elif is_numeric_dtype(dtype): feat_type.append('numerical') + else: + _error_due_to_unsupported_column(X, column) + return enc_columns, feat_type def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index 3d352d765..eea974ee8 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -328,13 +328,12 @@ def test_features_unsupported_calls_are_raised(): expected """ validator = TabularFeatureValidator() - with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"): - validator.fit( - pd.DataFrame({'datetime': [pd.Timestamp('20180310')]}) - ) + #with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."): + with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."): + validator.fit(pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})) with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"): validator.fit({'input1': 1, 'input2': 2}) - with pytest.raises(ValueError, match=r"has unsupported dtype string"): + with pytest.raises(TypeError, match=r"invalid type `string`."): validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string')) with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"): validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]), From c3e0fa0f5df2f3579c5aecc394131bf3fd6f67ca Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 06:34:25 +0900 Subject: [PATCH 05/10] [refactor] Separate convert all nan columns to numeric --- autoPyTorch/data/tabular_feature_validator.py | 99 +++++++++++-------- test/test_data/test_feature_validator.py | 1 - 2 files changed, 56 insertions(+), 44 deletions(-) diff --git 
a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 1cdce2963..a8b39f464 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -154,10 +154,42 @@ def _comparator(cmp1: str, cmp2: str) -> int: idx1, idx2 = choices.index(cmp1), choices.index(cmp2) return idx1 - idx2 - def _fit( - self, - X: SupportedFeatTypes, - ) -> BaseEstimator: + def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False) -> pd.DataFrame: + """ + Convert columns whose values were all nan in the training dataset to numeric. + + Args: + X (pd.DataFrame): + The data to transform. + fit (bool): + Whether this call is the fit to X or the transform using pre-fitted transformer. + """ + if not fit and self.all_nan_columns is None: + raise ValueError('_fit must be called before calling transform') + + if fit: + all_nan_columns = X.columns[X.isna().all()] + else: + assert self.all_nan_columns is not None + all_nan_columns = list(self.all_nan_columns) + + for col in all_nan_columns: + X[col] = np.nan + X[col] = pd.to_numeric(X[col]) + if len(self.dtypes): + self.dtypes[list(X.columns).index(col)] = X[col].dtype + + if has_object_columns(X.dtypes.values): + X = self.infer_objects(X) + + if fit: + # TODO: Check how to integrate below + # self.dtypes = [dt.name for dt in X.dtypes] + self.all_nan_columns = set(all_nan_columns) + + return X + + def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: """ In case input data is a pandas DataFrame, this utility encodes the user provided features (from categorical for example) to a numerical value that further stages @@ -180,23 +212,7 @@ def _fit( if ispandas(X) and not issparse(X): X = cast(pd.DataFrame, X) - # Treat a column with all instances a NaN as numerical - # This will prevent doing encoding to a categorical column made completely - # out of nan values -- which will trigger a fail, as encoding is not supported - # with nan 
values. - # Columns that are completely made of NaN values are provided to the pipeline - # so that later stages decide how to handle them - if np.any(pd.isnull(X)): - for column in X.columns: - if X[column].isna().all(): - X[column] = pd.to_numeric(X[column]) - # Also note this change in self.dtypes - if len(self.dtypes) != 0: - self.dtypes[list(X.columns).index(column)] = X[column].dtype - - if not X.select_dtypes(include='object').empty: - X = self.infer_objects(X) - + X = self._convert_all_nan_columns_to_numeric(X, fit=True) self.enc_columns, self.feat_type = self._get_columns_to_encode(X) assert self.feat_type is not None @@ -241,10 +257,7 @@ def _fit( self.num_features = np.shape(X)[1] return self - def transform( - self, - X: SupportedFeatTypes, - ) -> Union[np.ndarray, spmatrix, pd.DataFrame]: + def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.DataFrame]: """ Validates and fit a categorical encoder (if needed) to the features. The supported data types are List, numpy arrays and pandas DataFrames. 
@@ -264,19 +277,11 @@ def transform( # If a list was provided, it will be converted to pandas if isinstance(X, list): X = self.list_to_pandas(X) - - if isinstance(X, np.ndarray): + elif isinstance(X, np.ndarray): X = self.numpy_to_pandas(X) if ispandas(X) and not issparse(X): - if np.any(pd.isnull(X)): - for column in X.columns: - if X[column].isna().all(): - X[column] = pd.to_numeric(X[column]) - - # Also remove the object dtype for new data - if not X.select_dtypes(include='object').empty: - X = self.infer_objects(X) + X = self._convert_all_nan_columns_to_numeric(X) # Check the data here so we catch problems on new test data self._check_data(X) @@ -344,10 +349,7 @@ def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressio self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype return X - def _check_data( - self, - X: SupportedFeatTypes, - ) -> None: + def _check_data(self, X: SupportedFeatTypes) -> None: """ Feature dimensionality and data type checks @@ -481,10 +483,7 @@ def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: ) return X - def numpy_to_pandas( - self, - X: np.ndarray, - ) -> pd.DataFrame: + def numpy_to_pandas(self, X: np.ndarray) -> pd.DataFrame: """ Converts a numpy array to pandas for type inference @@ -533,3 +532,17 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame: self.object_dtype_mapping = {column: X[column].dtype for column in X.columns} self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}") return X + + +def has_object_columns(feature_types: pd.Series) -> bool: + """ + Indicate whether on a Series of dtypes for a Pandas DataFrame + there exists one or more object columns. + Args: + feature_types (pd.Series): The feature types for a DataFrame. + Returns: + bool: + True if the DataFrame dtypes contain an object column, False + otherwise. 
+ """ + return np.dtype('O') in feature_types diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index eea974ee8..baf96a719 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -328,7 +328,6 @@ def test_features_unsupported_calls_are_raised(): expected """ validator = TabularFeatureValidator() - #with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."): with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."): validator.fit(pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})) with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"): From 2394600ff35170b50cf1c13209d64668bb9ef70f Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 06:44:22 +0900 Subject: [PATCH 06/10] [refactor] Separate some processes --- autoPyTorch/data/tabular_feature_validator.py | 158 ++++++++++-------- 1 file changed, 90 insertions(+), 68 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index a8b39f464..c0b1cb1e7 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -189,6 +189,29 @@ def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False return X + def _encode_categories(self, X: pd.DataFrame) -> None: + preprocessors = get_tabular_preprocessors() + self.column_transformer = _create_column_transformer( + preprocessors=preprocessors, + categorical_columns=self.enc_columns, + ) + + assert self.column_transformer is not None # Mypy redefinition + self.column_transformer.fit(X) + + # The column transformer moves categoricals to the left side + self.feat_type = sorted(self.feat_type, key=functools.cmp_to_key(self._comparator)) + + encoded_categories = self.column_transformer.\ + named_transformers_['categorical_pipeline'].\ + 
named_steps['ordinalencoder'].categories_ + + # An ordinal encoder for each categorical columns + self.categories = [ + list(range(len(cat))) + for cat in encoded_categories + ] + def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: """ In case input data is a pandas DataFrame, this utility encodes the user provided @@ -216,36 +239,8 @@ def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: self.enc_columns, self.feat_type = self._get_columns_to_encode(X) assert self.feat_type is not None - if len(self.enc_columns) > 0: - - preprocessors = get_tabular_preprocessors() - self.column_transformer = _create_column_transformer( - preprocessors=preprocessors, - categorical_columns=self.enc_columns, - ) - - # Mypy redefinition - assert self.column_transformer is not None - self.column_transformer.fit(X) - - # The column transformer reorders the feature types - # therefore, we need to change the order of columns as well - # This means categorical columns are shifted to the left - self.feat_type = sorted( - self.feat_type, - key=functools.cmp_to_key(self._comparator) - ) - - encoded_categories = self.column_transformer.\ - named_transformers_['categorical_pipeline'].\ - named_steps['ordinalencoder'].categories_ - self.categories = [ - # We fit an ordinal encoder, where all categorical - # columns are shifted to the left - list(range(len(cat))) - for cat in encoded_categories - ] + self._encode_categories(X) for i, type_ in enumerate(self.feat_type): if 'numerical' in type_: @@ -253,7 +248,6 @@ def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: else: self.categorical_columns.append(i) - # Lastly, store the number of features self.num_features = np.shape(X)[1] return self @@ -270,6 +264,41 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat Return: np.ndarray: The transformed array + + Note: + The default transform performs the folloing: + * simple imputation for both + * scaling for numerical + * one-hot encoding for categorical + For 
example, here is a simple case + of which all the columns are categorical. + data = [ + {'A': 1, 'B': np.nan, 'C': np.nan}, + {'A': np.nan, 'B': 3, 'C': np.nan}, + {'A': 2, 'B': np.nan, 'C': np.nan} + ] + and suppose all the columns are categorical, + then + * `A` in {np.nan, 1, 2} + * `B` in {np.nan, 3} + * `C` in {np.nan} <=== it will be dropped. + + So in the column A, + * np.nan ==> [1, 0, 0] (always the index 0) + * 1 ==> [0, 1, 0] + * 2 ==> [0, 0, 1] + in the column B, + * np.nan ==> [1, 0] + * 3 ==> [0, 1] + Therefore, by concatenating, + * {'A': 1, 'B': np.nan, 'C': np.nan} ==> [0, 1, 0, 1, 0] + * {'A': np.nan, 'B': 3, 'C': np.nan} ==> [1, 0, 0, 0, 1] + * {'A': 2, 'B': np.nan, 'C': np.nan} ==> [0, 0, 1, 1, 0] + ==> [ + [0, 1, 0, 1, 0], + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 0] + ] """ if not self._is_fitted: raise NotFittedError("Cannot call transform on a validator that is not fitted") @@ -288,14 +317,6 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat # Pandas related transformations if ispandas(X) and self.column_transformer is not None: - if np.any(pd.isnull(X)): - # After above check it means that if there is a NaN - # the whole column must be NaN - # Make sure it is numerical and let the pipeline handle it - for column in X.columns: - if X[column].isna().all(): - X[column] = pd.to_numeric(X[column]) - X = self.column_transformer.transform(X) # Sparse related transformations @@ -304,17 +325,15 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat X.sort_indices() try: - X = sklearn.utils.check_array( - X, - force_all_finite=False, - accept_sparse='csr' - ) + X = sklearn.utils.check_array(X, force_all_finite=False, accept_sparse='csr') except Exception as e: - self.logger.exception(f"Conversion failed for input {X.dtypes} {X}" - "This means AutoPyTorch was not able to properly " - "Extract the dtypes of the provided input features. 
" - "Please try to manually cast it to a supported " - "numerical or categorical values.") + self.logger.exception( + f"Conversion failed for input {X.dtypes} {X}" + "This means AutoPyTorch was not able to properly " + "Extract the dtypes of the provided input features. " + "Please try to manually cast it to a supported " + "numerical or categorical values." + ) raise e X = self._compress_dataset(X) @@ -328,7 +347,6 @@ def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressio the testing data is converted to the same dtype as the training data. - Args: X (DatasetCompressionInputType): Dataset @@ -510,27 +528,31 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame: pd.DataFrame """ if hasattr(self, 'object_dtype_mapping'): - # Mypy does not process the has attr. This dict is defined below - for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type] - if 'int' in dtype.name: - # In the case train data was interpreted as int - # and test data was interpreted as float, because of 0.0 - # for example, honor training data - X[key] = X[key].applymap(np.int64) - else: - try: - X[key] = X[key].astype(dtype.name) - except Exception as e: - # Try inference if possible - self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}") - pass + # honor the training data types + try: + # Mypy does not process the has attr. 
+ X = X.astype(self.object_dtype_mapping) # type: ignore[has-type] + except Exception as e: + # Try inference if possible + self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type] + f'{self.object_dtype_mapping} caused the exception {e}') + pass else: - X = X.infer_objects() - for column in X.columns: - if not is_numeric_dtype(X[column]): - X[column] = X[column].astype('category') - self.object_dtype_mapping = {column: X[column].dtype for column in X.columns} + if len(self.dtypes) != 0: + # when train data has no object dtype, but test does + # we prioritise the datatype given in training data + dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)} + X = X.astype(dtype_dict) + else: + # Calling for the first time to infer the categories + X = X.infer_objects() + dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)} + X = X.astype(dtype_dict) + # only numerical attributes and categories + self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)} + self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}") + return X From 3216638a7d02c2dba7426085edfac80b6f85c75c Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 07:16:40 +0900 Subject: [PATCH 07/10] [refactor] Separate error handlings for the readability --- autoPyTorch/data/tabular_feature_validator.py | 95 ++++++++----------- test/test_data/test_feature_validator.py | 9 +- 2 files changed, 44 insertions(+), 60 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index c0b1cb1e7..9fcd7d39f 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -367,6 +367,28 @@ def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressio self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype return X + def 
_check_dataframe(self, X: pd.DataFrame) -> None: + err_msg = " of the features must be identical before/after fit(), " + err_msg += "but different between training and test datasets:\n" + + if has_object_columns(X.dtypes.values): + X = self.infer_objects(X) + + # Define the column to be encoded as the feature validator is fitted once per estimator + self.enc_columns, self.feat_type = self._get_columns_to_encode(X) + + column_order = [column for column in X.columns] + if len(self.column_order) == 0: + self.column_order = column_order + elif self.column_order != column_order: + raise ValueError(f"The column order{err_msg}train: {self.column_order}\ntest: {column_order}") + + dtypes = [dtype.name for dtype in X.dtypes] + if len(self.dtypes) == 0: + self.dtypes = dtypes + elif self.dtypes != dtypes: + raise ValueError(f"The dtypes{err_msg}train: {self.dtypes}\ntest: {dtypes}") + def _check_data(self, X: SupportedFeatTypes) -> None: """ Feature dimensionality and data type checks @@ -378,68 +400,29 @@ def _check_data(self, X: SupportedFeatTypes) -> None: """ if not isinstance(X, (np.ndarray, pd.DataFrame)) and not issparse(X): - raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames," - " scipy sparse and Python Lists, yet, the provided input is" - " of type {}".format(type(X)) - ) + raise TypeError( + "AutoPyTorch only supports numpy.ndarray, pandas.DataFrame," + f" scipy.sparse and List, but got {type(X)}" + ) if self.data_type is None: self.data_type = type(X) if self.data_type != type(X): - self.logger.warning("AutoPyTorch previously received features of type %s " - "yet the current features have type %s. Changing the dtype " - "of inputs to an estimator might cause problems" % ( - str(self.data_type), - str(type(X)), - ), - ) - - # Do not support category/string numpy data. 
Only numbers - if hasattr(X, "dtype"): - if not np.issubdtype(X.dtype.type, np.number): # type: ignore[union-attr] - raise ValueError( - "When providing a numpy array to AutoPyTorch, the only valid " - "dtypes are numerical ones. The provided data type {} is not supported." - "".format( - X.dtype.type, # type: ignore[union-attr] - ) - ) - - # Then for Pandas, we do not support Nan in categorical columns - if ispandas(X): - # If entered here, we have a pandas dataframe - X = cast(pd.DataFrame, X) - - # Handle objects if possible - if not X.select_dtypes(include='object').empty: - X = self.infer_objects(X) + self.logger.warning( + f"AutoPyTorch previously received features of type {str(self.data_type)}, " + f"but got type {str(type(X))} in the current features. This change might cause problems" + ) - # Define the column to be encoded here as the feature validator is fitted once - # per estimator - self.enc_columns, self.feat_type = self._get_columns_to_encode(X) + if ispandas(X): # For pandas, no support of nan in categorical cols + X = cast(pd.DataFrame, X) + self._check_dataframe(X) - column_order = [column for column in X.columns] - if len(self.column_order) > 0: - if self.column_order != column_order: - raise ValueError("Changing the column order of the features after fit() is " - "not supported. Fit() method was called with " - "{} whereas the new features have {} as type".format(self.column_order, - column_order,) - ) - else: - self.column_order = column_order - - dtypes = [dtype.name for dtype in X.dtypes] - if len(self.dtypes) > 0: - if self.dtypes != dtypes: - raise ValueError("Changing the dtype of the features after fit() is " - "not supported. 
Fit() method was called with " - "{} whereas the new features have {} as type".format(self.dtypes, - dtypes, - ) - ) - else: - self.dtypes = dtypes + # For ndarray, no support of category/string + if isinstance(X, np.ndarray) and not np.issubdtype(X.dtype.type, np.number): + dt = X.dtype.type + raise TypeError( + f"AutoPyTorch does not support numpy.ndarray with non-numerical dtype, but got {dt}" + ) def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], List[str]]: """ diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index baf96a719..c0d497ad9 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -220,7 +220,7 @@ def test_featurevalidator_supported_types(input_data_featuretest): ) def test_featurevalidator_unsupported_numpy(input_data_featuretest): validator = TabularFeatureValidator() - with pytest.raises(ValueError, match=r".*When providing a numpy array.*not supported."): + with pytest.raises(TypeError, match=r"AutoPyTorch does not support numpy.ndarray with non-numerical dtype"): validator.fit(input_data_featuretest) @@ -330,7 +330,7 @@ def test_features_unsupported_calls_are_raised(): validator = TabularFeatureValidator() with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."): validator.fit(pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})) - with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"): + with pytest.raises(TypeError, match=r"AutoPyTorch only supports numpy.ndarray, pandas.DataFrame"): validator.fit({'input1': 1, 'input2': 2}) with pytest.raises(TypeError, match=r"invalid type `string`."): validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string')) @@ -515,15 +515,16 @@ def test_featurevalidator_new_data_after_fit(openml_id, # And then check proper error messages if train_data_type == 'pandas': + pattern = r"of the features must be identical before/after fit()" 
old_dtypes = copy.deepcopy(validator.dtypes) validator.dtypes = ['dummy' for dtype in X_train.dtypes] - with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"): + with pytest.raises(ValueError, match=pattern): transformed_X = validator.transform(X_test) validator.dtypes = old_dtypes if test_data_type == 'pandas': columns = X_test.columns.tolist() X_test = X_test[reversed(columns)] - with pytest.raises(ValueError, match=r"Changing the column order of the features"): + with pytest.raises(ValueError, match=pattern): transformed_X = validator.transform(X_test) From 999ae64d3b7b36f8476e61a1aaae0728a2317c2c Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 07:51:56 +0900 Subject: [PATCH 08/10] [feat] [fix] Add num/cat enum and fix an error in a test --- autoPyTorch/data/tabular_feature_validator.py | 40 ++++++++++++++----- autoPyTorch/utils/common.py | 3 ++ test/test_data/test_validation.py | 2 +- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 9fcd7d39f..7d6b63dd6 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -1,6 +1,12 @@ +""" +TODO: + 1. Add dtypes argument to TabularFeatureValidator + 2. Modify dtypes from List[str] to Dict[str, str] + 3. 
Add the feature to enforce the dtype to the provided dtypes +""" import functools from logging import Logger -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union, cast import numpy as np @@ -23,10 +29,20 @@ DatasetDTypeContainerType, reduce_dataset_size_if_too_large ) -from autoPyTorch.utils.common import ispandas +from autoPyTorch.utils.common import autoPyTorchEnum, ispandas from autoPyTorch.utils.logging_ import PicklableClientLogger +class ColumnDTypes(autoPyTorchEnum): + numerical = "numerical" + categorical = "categorical" + + +def convert_dtype_enum_dict_to_str_dict(dtype_dict: Dict[str, ColumnDTypes]) -> Dict[str, str]: + enum2str = {type_choice: str(type_choice) for type_choice in ColumnDTypes} + return {col_name: enum2str[dtype_choice] for col_name, dtype_choice in dtype_dict.items()} + + def _create_column_transformer( preprocessors: Dict[str, List[BaseEstimator]], categorical_columns: List[str], @@ -129,6 +145,7 @@ def __init__( ) -> None: self._dataset_compression = dataset_compression self._reduced_dtype: Optional[DatasetDTypeContainerType] = None + self.all_nan_columns: Optional[Set[str]] = None super().__init__(logger) @staticmethod @@ -146,10 +163,12 @@ def _comparator(cmp1: str, cmp2: str) -> int: Returns: int: either [0, -1, 1] """ - choices = ['categorical', 'numerical'] + choices = [str(ColumnDTypes.categorical), str(ColumnDTypes.numerical)] if cmp1 not in choices or cmp2 not in choices: - raise ValueError('The comparator for the column order only accepts {}, ' - 'but got {} and {}'.format(choices, cmp1, cmp2)) + raise ValueError( + f"The comparator for the column order only accepts {choices}, " + f"but got {cmp1} and {cmp2}" + ) idx1, idx2 = choices.index(cmp1), choices.index(cmp2) return idx1 - idx2 @@ -164,7 +183,7 @@ def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False fit (bool): Whether this call is the fit to X or 
the transform using pre-fitted transformer. """ - if not fit and self.all_nan_columns is None: + if not fit and not issparse(X) and self.all_nan_columns is None: raise ValueError('_fit must be called before calling transform') if fit: @@ -200,6 +219,7 @@ def _encode_categories(self, X: pd.DataFrame) -> None: self.column_transformer.fit(X) # The column transformer moves categoricals to the left side + assert self.feat_type is not None self.feat_type = sorted(self.feat_type, key=functools.cmp_to_key(self._comparator)) encoded_categories = self.column_transformer.\ @@ -242,8 +262,8 @@ def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: if len(self.enc_columns) > 0: self._encode_categories(X) - for i, type_ in enumerate(self.feat_type): - if 'numerical' in type_: + for i, type_name in enumerate(self.feat_type): + if ColumnDTypes.numerical in type_name: self.numerical_columns.append(i) else: self.categorical_columns.append(i) @@ -453,9 +473,9 @@ def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], List[str]] for dtype, column in zip(X.dtypes, X.columns): if dtype.name in ['category', 'bool']: enc_columns.append(column) - feat_type.append('categorical') + feat_type.append(str(ColumnDTypes.categorical)) elif is_numeric_dtype(dtype): - feat_type.append('numerical') + feat_type.append(str(ColumnDTypes.numerical)) else: _error_due_to_unsupported_column(X, column) diff --git a/autoPyTorch/utils/common.py b/autoPyTorch/utils/common.py index 48302bdee..23d3908e7 100644 --- a/autoPyTorch/utils/common.py +++ b/autoPyTorch/utils/common.py @@ -101,6 +101,9 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(self.value) + def __str__(self) -> str: + return str(self.value) + def custom_collate_fn(batch: List) -> List[Optional[torch.Tensor]]: """ diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index 482c99769..cc89f5276 100644 --- a/test/test_data/test_validation.py +++ 
b/test/test_data/test_validation.py @@ -103,7 +103,7 @@ def test_sparse_data_validation_for_regression(): validator.fit(X_train=X_sp, y_train=y) - X_t, y_t = validator.transform(X, y) + X_t, y_t = validator.transform(X_sp, y) assert np.shape(X) == np.shape(X_t) # make sure everything was encoded to number From 6bec5c41a6ceb860484e448ce6185f3b4b56f95d Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 18:45:29 +0900 Subject: [PATCH 09/10] [refactor] Separate many functions into utils --- autoPyTorch/data/base_feature_validator.py | 4 +- autoPyTorch/data/tabular_feature_validator.py | 162 +++++------------- autoPyTorch/data/tabular_target_validator.py | 17 +- autoPyTorch/data/utils.py | 125 +++++++++++++- 4 files changed, 170 insertions(+), 138 deletions(-) diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index 2b5183550..d24a2eb4f 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -37,7 +37,7 @@ def __init__( # Register types to detect unsupported data format changes self.feat_type: Optional[List[str]] = None self.data_type: Optional[type] = None - self.dtypes: List[str] = [] + self.dtypes: Dict[str, str] = {} self.column_order: List[str] = [] self.column_transformer: Optional[BaseEstimator] = None diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 7d6b63dd6..282f8508a 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -15,7 +15,6 @@ from scipy.sparse import issparse, spmatrix -import sklearn.utils from sklearn import preprocessing from sklearn.base import BaseEstimator from sklearn.compose import ColumnTransformer @@ -25,24 +24,19 @@ from autoPyTorch.data.base_feature_validator 
import BaseFeatureValidator, SupportedFeatTypes from autoPyTorch.data.utils import ( + ColumnDTypes, DatasetCompressionInputType, DatasetDTypeContainerType, + _categorical_left_mover, + _check_and_to_array, + _get_columns_to_encode, + has_object_columns, reduce_dataset_size_if_too_large ) -from autoPyTorch.utils.common import autoPyTorchEnum, ispandas +from autoPyTorch.utils.common import ispandas from autoPyTorch.utils.logging_ import PicklableClientLogger -class ColumnDTypes(autoPyTorchEnum): - numerical = "numerical" - categorical = "categorical" - - -def convert_dtype_enum_dict_to_str_dict(dtype_dict: Dict[str, ColumnDTypes]) -> Dict[str, str]: - enum2str = {type_choice: str(type_choice) for type_choice in ColumnDTypes} - return {col_name: enum2str[dtype_choice] for col_name, dtype_choice in dtype_dict.items()} - - def _create_column_transformer( preprocessors: Dict[str, List[BaseEstimator]], categorical_columns: List[str], @@ -90,28 +84,6 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]: return preprocessors -def _error_due_to_unsupported_column(X: pd.DataFrame, column: str) -> None: - # Move away from np.issubdtype as it causes - # TypeError: data type not understood in certain pandas types - def _generate_error_message_prefix(type_name: str, proc_type: Optional[str] = None) -> str: - msg1 = f"column `{column}` has an invalid type `{type_name}`. " - msg2 = "Cast it to a numerical type, category type or bool type by astype method. 
" - msg3 = f"The following link might help you to know {proc_type} processing: " - return msg1 + msg2 + ("" if proc_type is None else msg3) - - dtype = X[column].dtype - if dtype.name == 'object': - err_msg = _generate_error_message_prefix(type_name="object", proc_type="string") - url = "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html" - raise TypeError(f"{err_msg}{url}") - elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(dtype): - err_msg = _generate_error_message_prefix(type_name="time and/or date datatype", proc_type="datetime") - raise TypeError(f"{err_msg}https://stats.stackexchange.com/questions/311494/") - else: - err_msg = _generate_error_message_prefix(type_name=dtype.name) - raise TypeError(err_msg) - - class TabularFeatureValidator(BaseFeatureValidator): """ A subclass of `BaseFeatureValidator` made for tabular data. @@ -142,36 +114,14 @@ def __init__( self, logger: Optional[Union[PicklableClientLogger, Logger]] = None, dataset_compression: Optional[Mapping[str, Any]] = None, + dtypes: Optional[Dict[str, str]] = None, ) -> None: + super().__init__(logger) self._dataset_compression = dataset_compression self._reduced_dtype: Optional[DatasetDTypeContainerType] = None self.all_nan_columns: Optional[Set[str]] = None - super().__init__(logger) - - @staticmethod - def _comparator(cmp1: str, cmp2: str) -> int: - """Order so that categorical columns come left and numerical columns come right - - Args: - cmp1 (str): First variable to compare - cmp2 (str): Second variable to compare - - Raises: - ValueError: if the values of the variables to compare - are not in 'categorical' or 'numerical' - - Returns: - int: either [0, -1, 1] - """ - choices = [str(ColumnDTypes.categorical), str(ColumnDTypes.numerical)] - if cmp1 not in choices or cmp2 not in choices: - raise ValueError( - f"The comparator for the column order only accepts {choices}, " - f"but got {cmp1} and {cmp2}" - ) - - idx1, idx2 = choices.index(cmp1), 
choices.index(cmp2) - return idx1 - idx2 + self.dtypes = dtypes if dtypes is not None else {} + self._called_infer_object = False def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False) -> pd.DataFrame: """ @@ -196,7 +146,7 @@ def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False X[col] = np.nan X[col] = pd.to_numeric(X[col]) if len(self.dtypes): - self.dtypes[list(X.columns).index(col)] = X[col].dtype + self.dtypes[col] = X[col].dtype.name if has_object_columns(X.dtypes.values): X = self.infer_objects(X) @@ -208,6 +158,10 @@ def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False return X + @staticmethod + def _comparator(cmp1: str, cmp2: str) -> int: + return _categorical_left_mover(cmp1, cmp2) + def _encode_categories(self, X: pd.DataFrame) -> None: preprocessors = get_tabular_preprocessors() self.column_transformer = _create_column_transformer( @@ -345,7 +299,7 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat X.sort_indices() try: - X = sklearn.utils.check_array(X, force_all_finite=False, accept_sparse='csr') + X = _check_and_to_array(X) except Exception as e: self.logger.exception( f"Conversion failed for input {X.dtypes} {X}" @@ -391,9 +345,6 @@ def _check_dataframe(self, X: pd.DataFrame) -> None: err_msg = " of the features must be identical before/after fit(), " err_msg += "but different between training and test datasets:\n" - if has_object_columns(X.dtypes.values): - X = self.infer_objects(X) - # Define the column to be encoded as the feature validator is fitted once per estimator self.enc_columns, self.feat_type = self._get_columns_to_encode(X) @@ -403,7 +354,7 @@ def _check_dataframe(self, X: pd.DataFrame) -> None: elif self.column_order != column_order: raise ValueError(f"The column order{err_msg}train: {self.column_order}\ntest: {column_order}") - dtypes = [dtype.name for dtype in X.dtypes] + dtypes = {col: dtype.name for col, dtype in 
zip(X.columns, X.dtypes)} if len(self.dtypes) == 0: self.dtypes = dtypes elif self.dtypes != dtypes: @@ -444,7 +395,7 @@ def _check_data(self, X: SupportedFeatTypes) -> None: f"AutoPyTorch does not support numpy.ndarray with non-numerical dtype, but got {dt}" ) - def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], List[str]]: + def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], Dict[str, str]]: """ Return the columns to be encoded from a pandas dataframe @@ -455,31 +406,15 @@ def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], List[str]] Returns: enc_columns (List[str]): - Columns to encode, if any - feat_type: - Type of each column numerical/categorical + Columns to encode + feat_type (Dict[str, str]): + Whether each column is numerical or categorical """ if len(self.enc_columns) > 0 and self.feat_type is not None: return self.enc_columns, self.feat_type - - # Register if a column needs encoding - enc_columns = [] - - # Also, register the feature types for the estimator - feat_type = [] - - # Make sure each column is a valid type - for dtype, column in zip(X.dtypes, X.columns): - if dtype.name in ['category', 'bool']: - enc_columns.append(column) - feat_type.append(str(ColumnDTypes.categorical)) - elif is_numeric_dtype(dtype): - feat_type.append(str(ColumnDTypes.numerical)) - else: - _error_due_to_unsupported_column(X, column) - - return enc_columns, feat_type + else: + return _get_columns_to_encode(X) def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: """ @@ -530,44 +465,25 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame: Returns: pd.DataFrame """ - if hasattr(self, 'object_dtype_mapping'): + if self._called_infer_object: # honor the training data types try: # Mypy does not process the has attr. 
- X = X.astype(self.object_dtype_mapping) # type: ignore[has-type] + X = X.astype(self.dtypes) # type: ignore[has-type] except Exception as e: - # Try inference if possible - self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type] - f'{self.object_dtype_mapping} caused the exception {e}') - pass - else: - if len(self.dtypes) != 0: - # when train data has no object dtype, but test does - # we prioritise the datatype given in training data - dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)} - X = X.astype(dtype_dict) - else: - # Calling for the first time to infer the categories - X = X.infer_objects() - dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)} - X = X.astype(dtype_dict) - # only numerical attributes and categories - self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)} - - self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}") + self.logger.warning( + 'Casting the columns to training dtypes ' + f'{self.dtypes} caused the exception {e}' # type: ignore[has-type] + ) + elif len(self.dtypes): # Overwrite the dtypes in test data by those in the training data + X = X.astype(self.dtypes) + else: # Calling for the first time to infer the categories + X = X.infer_objects() + cat_dtypes = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)} + X = X.astype(cat_dtypes) + + self.dtypes.update({col: dtype.name for col, dtype in zip(X.columns, X.dtypes)}) + self.logger.debug(f"New dtypes of data: {self.dtypes}") + self._called_infer_object = True return X - - -def has_object_columns(feature_types: pd.Series) -> bool: - """ - Indicate whether on a Series of dtypes for a Pandas DataFrame - there exists one or more object columns. - Args: - feature_types (pd.Series): The feature types for a DataFrame. 
- Returns: - bool: - True if the DataFrame dtypes contain an object column, False - otherwise. - """ - return np.dtype('O') in feature_types diff --git a/autoPyTorch/data/tabular_target_validator.py b/autoPyTorch/data/tabular_target_validator.py index 22cabb999..693a24cae 100644 --- a/autoPyTorch/data/tabular_target_validator.py +++ b/autoPyTorch/data/tabular_target_validator.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, cast +from typing import List, Optional, cast import numpy as np @@ -7,24 +7,16 @@ from scipy.sparse import issparse, spmatrix -import sklearn.utils from sklearn import preprocessing from sklearn.base import BaseEstimator from sklearn.exceptions import NotFittedError from sklearn.utils.multiclass import type_of_target from autoPyTorch.data.base_target_validator import BaseTargetValidator, SupportedTargetTypes +from autoPyTorch.data.utils import ArrayType, _check_and_to_array from autoPyTorch.utils.common import ispandas -ArrayType = Union[np.ndarray, spmatrix] - - -def _check_and_to_array(y: SupportedTargetTypes) -> ArrayType: - """ sklearn check array will make sure we have the correct numerical features for the array """ - return sklearn.utils.check_array(y, force_all_finite=True, accept_sparse='csr', ensure_2d=False) - - def _modify_regression_target(y: ArrayType) -> ArrayType: # Regression targets must have numbers after a decimal point. 
# Ref: https://github.com/scikit-learn/scikit-learn/issues/8952 @@ -124,8 +116,9 @@ def _fit( return self def _transform_by_encoder(self, y: SupportedTargetTypes) -> np.ndarray: + kwargs = dict(force_all_finite=True, ensure_2d=False) if self.encoder is None: - return _check_and_to_array(y) + return _check_and_to_array(y, **kwargs) # remove ravel warning from pandas Series shape = np.shape(y) @@ -139,7 +132,7 @@ def _transform_by_encoder(self, y: SupportedTargetTypes) -> np.ndarray: else: y = self.encoder.transform(np.array(y).reshape(-1, 1)).reshape(-1) - return _check_and_to_array(y) + return _check_and_to_array(y, **kwargs) def transform(self, y: SupportedTargetTypes) -> np.ndarray: """ diff --git a/autoPyTorch/data/utils.py b/autoPyTorch/data/utils.py index 03375ce27..410cb03bb 100644 --- a/autoPyTorch/data/utils.py +++ b/autoPyTorch/data/utils.py @@ -18,11 +18,18 @@ import numpy as np import pandas as pd +from pandas.api.types import is_numeric_dtype from scipy.sparse import issparse, spmatrix -from autoPyTorch.utils.common import ispandas +from sklearn.utils import check_array +from autoPyTorch.data.base_target_validator import SupportedTargetTypes +from autoPyTorch.data.base_feature_validator import SupportedFeatTypes +from autoPyTorch.utils.common import autoPyTorchEnum, ispandas + + +ArrayType = Union[np.ndarray, spmatrix] # TODO: TypedDict with python 3.8 # @@ -39,6 +46,122 @@ } +class ColumnDTypes(autoPyTorchEnum): + numerical = "numerical" + categorical = "categorical" + + +def convert_dtype_enum_dict_to_str_dict(dtype_dict: Dict[str, ColumnDTypes]) -> Dict[str, str]: + enum2str = {type_choice: str(type_choice) for type_choice in ColumnDTypes} + return {col_name: enum2str[dtype_choice] for col_name, dtype_choice in dtype_dict.items()} + + +def has_object_columns(feature_types: pd.Series) -> bool: + """ + Indicate whether on a Series of dtypes for a Pandas DataFrame + there exists one or more object columns. 
+ Args: + feature_types (pd.Series): The feature types for a DataFrame. + Returns: + bool: + True if the DataFrame dtypes contain an object column, False + otherwise. + """ + return np.dtype('O') in feature_types + + +def _check_and_to_array( + data: Union[SupportedFeatTypes, SupportedTargetTypes], + **kwargs: Dict[str, Any] +) -> ArrayType: + """ sklearn check array will make sure we have the correct numerical features for the array """ + _kwargs = dict(accept_sparse='csr', force_all_finite=False) + _kwargs.update(kwargs) + return check_array(data, **_kwargs) + + +def _error_due_to_unsupported_column(X: pd.DataFrame, column: str) -> None: + # Move away from np.issubdtype as it causes + # TypeError: data type not understood in certain pandas types + def _generate_error_message_prefix(type_name: str, proc_type: Optional[str] = None) -> str: + msg1 = f"column `{column}` has an invalid type `{type_name}`. " + msg2 = "Cast it to a numerical type, category type or bool type by astype method. 
" + msg3 = f"The following link might help you to know {proc_type} processing: " + return msg1 + msg2 + ("" if proc_type is None else msg3) + + dtype = X[column].dtype + if dtype.name == 'object': + err_msg = _generate_error_message_prefix(type_name="object", proc_type="string") + url = "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html" + raise TypeError(f"{err_msg}{url}") + elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(dtype): + err_msg = _generate_error_message_prefix(type_name="time and/or date datatype", proc_type="datetime") + raise TypeError(f"{err_msg}https://stats.stackexchange.com/questions/311494/") + else: + err_msg = _generate_error_message_prefix(type_name=dtype.name) + raise TypeError(err_msg) + + +def _get_columns_to_encode(X: pd.DataFrame) -> Tuple[List[str], Dict[str, str]]: + """ + In case input data is a pandas DataFrame, this utility encodes the user provided + features (from categorical for example) to a numerical value that further stages + will be able to use + + Args: + X (pd.DataFrame): + A set of features that are going to be validated (type and dimensionality + checks) and an encoder fitted in the case the data needs encoding + + Returns: + enc_columns (List[str]): + Columns to encode + feat_type (Dict[str, str]): + Whether each column is numerical or categorical + """ + enc_columns: List[str] = [] + # feat_type: Dict[str, str] = {} + feat_type: List[str] = [] + + for dtype, col in zip(X.dtypes, X.columns): + if dtype.name in ['category', 'bool']: + enc_columns.append(col) + # feat_type[col] = str(ColumnDTypes.categorical) + feat_type.append(str(ColumnDTypes.categorical)) + elif is_numeric_dtype(dtype): + # feat_type[col] = str(ColumnDTypes.numerical) + feat_type.append(str(ColumnDTypes.numerical)) + else: + _error_due_to_unsupported_column(X, col) + + return enc_columns, feat_type + + +def _categorical_left_mover(cmp1: str, cmp2: str) -> int: + """Order so that categorical columns come left 
and numerical columns come right + + Args: + cmp1 (str): First variable to compare + cmp2 (str): Second variable to compare + + Raises: + ValueError: if the values of the variables to compare + are not in 'categorical' or 'numerical' + + Returns: + int: either [0, -1, 1] + """ + choices = [str(ColumnDTypes.categorical), str(ColumnDTypes.numerical)] + if cmp1 not in choices or cmp2 not in choices: + raise ValueError( + f"The comparator for the column order only accepts {choices}, " + f"but got {cmp1} and {cmp2}" + ) + + idx1, idx2 = choices.index(cmp1), choices.index(cmp2) + return idx1 - idx2 + + def get_dataset_compression_mapping( memory_limit: int, dataset_compression: Union[bool, Mapping[str, Any]] From 2d9f4f498ba61ba8c6843a4f7e1a41e73075fea2 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 3 Mar 2022 19:06:40 +0900 Subject: [PATCH 10/10] [refactor] Separate processes as a function and add them to util --- autoPyTorch/data/base_feature_validator.py | 12 +-- autoPyTorch/data/base_target_validator.py | 8 +- autoPyTorch/data/tabular_feature_validator.py | 64 ++-------------- autoPyTorch/data/utils.py | 73 ++++++++++++++++++- 4 files changed, 80 insertions(+), 77 deletions(-) diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index d24a2eb4f..4f64b429a 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -3,18 +3,12 @@ import numpy as np -import pandas as pd - -from scipy.sparse import spmatrix - from sklearn.base import BaseEstimator +from autoPyTorch.data.utils import SupportedFeatTypes, list_to_pandas from autoPyTorch.utils.logging_ import PicklableClientLogger -SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, spmatrix] - - class BaseFeatureValidator(BaseEstimator): """ A class to pre-process features. 
In this regards, the format of the data is checked, @@ -75,8 +69,8 @@ def fit( # If a list was provided, it will be converted to pandas if isinstance(X_train, list): - X_train = self.list_to_pandas(X_train) - X_test = self.list_to_pandas(X_test) if X_test is not None else None + X_train = list_to_pandas(X_train, self.logger) + X_test = list_to_pandas(X_test, self.logger) if X_test is not None else None self._check_data(X_train) diff --git a/autoPyTorch/data/base_target_validator.py b/autoPyTorch/data/base_target_validator.py index 530675fbd..f5099ff62 100644 --- a/autoPyTorch/data/base_target_validator.py +++ b/autoPyTorch/data/base_target_validator.py @@ -1,18 +1,14 @@ import logging -from typing import List, Optional, Union, cast +from typing import Optional, Union, cast import numpy as np import pandas as pd -from scipy.sparse import spmatrix - from sklearn.base import BaseEstimator from autoPyTorch.utils.logging_ import PicklableClientLogger - - -SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, spmatrix] +from autoPyTorch.data.utils import SupportedTargetTypes class BaseTargetValidator(BaseEstimator): diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 282f8508a..cecd90257 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -31,7 +31,8 @@ _check_and_to_array, _get_columns_to_encode, has_object_columns, - reduce_dataset_size_if_too_large + reduce_dataset_size_if_too_large, + to_pandas, ) from autoPyTorch.utils.common import ispandas from autoPyTorch.utils.logging_ import PicklableClientLogger @@ -202,10 +203,7 @@ def _fit(self, X: SupportedFeatTypes) -> BaseEstimator: The fitted base estimator """ - # The final output of a validator is a numpy array. 
But pandas - # gives us information about the column dtype - if isinstance(X, np.ndarray): - X = self.numpy_to_pandas(X) + X = to_pandas(X) # there is the column dtype info, so convert it to pandas if ispandas(X) and not issparse(X): X = cast(pd.DataFrame, X) @@ -277,12 +275,7 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat if not self._is_fitted: raise NotFittedError("Cannot call transform on a validator that is not fitted") - # If a list was provided, it will be converted to pandas - if isinstance(X, list): - X = self.list_to_pandas(X) - elif isinstance(X, np.ndarray): - X = self.numpy_to_pandas(X) - + X = to_pandas(X) if ispandas(X) and not issparse(X): X = self._convert_all_nan_columns_to_numeric(X) @@ -298,18 +291,7 @@ def transform(self, X: SupportedFeatTypes) -> Union[np.ndarray, spmatrix, pd.Dat if issparse(X) and hasattr(X, 'sort_indices'): X.sort_indices() - try: - X = _check_and_to_array(X) - except Exception as e: - self.logger.exception( - f"Conversion failed for input {X.dtypes} {X}" - "This means AutoPyTorch was not able to properly " - "Extract the dtypes of the provided input features. " - "Please try to manually cast it to a supported " - "numerical or categorical values." - ) - raise e - + X = _check_and_to_array(X, logger=self.logger) X = self._compress_dataset(X) return X @@ -416,42 +398,6 @@ def _get_columns_to_encode(self, X: pd.DataFrame) -> Tuple[List[str], Dict[str, else: return _get_columns_to_encode(X) - def list_to_pandas(self, X: SupportedFeatTypes) -> pd.DataFrame: - """ - Convert a list to a pandas DataFrame. In this process, column types are inferred. 
- - Args: - X (SupportedFeatTypes): - A set of features that are going to be validated (type and dimensionality - checks) and an encoder fitted in the case the data needs encoding - - Returns: - pd.DataFrame: - transformed data from list to pandas DataFrame - """ - - # If a list was provided, it will be converted to pandas - X = pd.DataFrame(data=X).infer_objects() - data_info = [(col, t) for col, t in zip(X.columns, X.dtypes)] - self.logger.warning( - "The provided feature types to AutoPyTorch are list." - f"Features have been interpreted as: {data_info}" - ) - return X - - def numpy_to_pandas(self, X: np.ndarray) -> pd.DataFrame: - """ - Converts a numpy array to pandas for type inference - - Args: - X (np.ndarray): - data to be interpreted. - - Returns: - pd.DataFrame - """ - return pd.DataFrame(X).infer_objects().convert_dtypes() - def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame: """ In case the input contains object columns, their type is inferred if possible diff --git a/autoPyTorch/data/utils.py b/autoPyTorch/data/utils.py index 410cb03bb..40dc7aa11 100644 --- a/autoPyTorch/data/utils.py +++ b/autoPyTorch/data/utils.py @@ -1,5 +1,6 @@ # Implementation used from https://github.com/automl/auto-sklearn/blob/development/autosklearn/util/data.py import warnings +from logging import Logger from math import floor from typing import ( Any, @@ -24,12 +25,12 @@ from sklearn.utils import check_array -from autoPyTorch.data.base_target_validator import SupportedTargetTypes -from autoPyTorch.data.base_feature_validator import SupportedFeatTypes from autoPyTorch.utils.common import autoPyTorchEnum, ispandas ArrayType = Union[np.ndarray, spmatrix] +SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, spmatrix] +SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, spmatrix] # TODO: TypedDict with python 3.8 # @@ -56,6 +57,60 @@ def convert_dtype_enum_dict_to_str_dict(dtype_dict: Dict[str, ColumnDTypes]) -> return {col_name: 
enum2str[dtype_choice] for col_name, dtype_choice in dtype_dict.items()} +def list_to_pandas(data: List, logger: Optional[Logger] = None) -> pd.DataFrame: + """ + Convert a list to a pandas DataFrame. In this process, column types are inferred. + + Args: + data (List): + A list of features. + + Returns: + pd.DataFrame: + transformed data from list to pandas DataFrame + """ + if not isinstance(data, list): + raise TypeError(f"data must be list, but got {type(data)}") + + # If a list was provided, it will be converted to pandas + data = pd.DataFrame(data=data).infer_objects() + data_info = [(col, t) for col, t in zip(data.columns, data.dtypes)] + + if logger is not None: + logger.warning( + "The provided feature types to AutoPyTorch are list." + f"Features have been interpreted as: {data_info}" + ) + + return data + + +def numpy_to_pandas(data: np.ndarray) -> pd.DataFrame: + """ + Converts a numpy array to pandas for type inference + + Args: + data (np.ndarray): + data to be interpreted. + + Returns: + pd.DataFrame + """ + if not isinstance(data, np.ndarray): + raise TypeError(f"data must be np.ndarray, but got {type(data)}") + + return pd.DataFrame(data).infer_objects().convert_dtypes() + + +def to_pandas(data: SupportedFeatTypes, logger: Optional[Logger] = None) -> SupportedFeatTypes: + if isinstance(data, list): + data = list_to_pandas(data, logger) + elif isinstance(data, np.ndarray): + data = numpy_to_pandas(data) + + return data + + +def has_object_columns(feature_types: pd.Series) -> bool: """ Indicate whether on a Series of dtypes for a Pandas DataFrame @@ -72,12 +127,24 @@ def has_object_columns(feature_types: pd.Series) -> bool: def _check_and_to_array( data: Union[SupportedFeatTypes, SupportedTargetTypes], + logger: Optional[Logger] = None, **kwargs: Dict[str, Any] ) -> ArrayType: """ sklearn check array will make sure we have the correct numerical features for the array """ _kwargs = dict(accept_sparse='csr', force_all_finite=False) _kwargs.update(kwargs) -
return check_array(data, **_kwargs) + try: + return check_array(data, **_kwargs) + except Exception as e: + if logger is not None: + logger.exception( + f"Conversion failed for input {data}" + "This means AutoPyTorch was not able to properly " + "Extract the dtypes of the provided input features. " + "Please try to manually cast it to a supported " + "numerical or categorical values." + ) + raise e def _error_due_to_unsupported_column(X: pd.DataFrame, column: str) -> None: