59 changes: 55 additions & 4 deletions autoPyTorch/api/tabular_classification.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np

@@ -11,6 +11,10 @@
TASK_TYPES_TO_STRING,
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.data.utils import (
default_dataset_compression_arg,
validate_dataset_compression_arg
)
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
HoldoutValTypes,
@@ -163,6 +167,7 @@ def _get_dataset_input_validator(
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
dataset_compression: Optional[Mapping[str, Any]] = None,
) -> Tuple[TabularDataset, TabularInputValidator]:
"""
Returns an object of `TabularDataset` and an object of
@@ -202,6 +207,7 @@ def _get_dataset_input_validator(
InputValidator = TabularInputValidator(
is_classification=True,
logger_port=self._logger_port,
dataset_compression=dataset_compression
)

# Fit an input validator to check the provided data
@@ -234,14 +240,15 @@ def search(
total_walltime_limit: int = 100,
func_eval_time_limit_secs: Optional[int] = None,
enable_traditional_pipeline: bool = True,
memory_limit: Optional[int] = 4096,
memory_limit: int = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
load_models: bool = True,
portfolio_selection: Optional[str] = None,
dataset_compression: Union[Mapping[str, Any], bool] = False,
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.
@@ -310,7 +317,7 @@ def search(
feature by turning this flag to False. All machine learning
algorithms that are fitted during search() are considered for
ensemble building.
memory_limit (Optional[int]: default=4096):
memory_limit (int: default=4096):
Memory limit in MB for the machine learning algorithm.
Autopytorch will stop fitting the machine learning algorithm
if it tries to allocate more than memory_limit MB. If None
@@ -368,11 +375,54 @@ def search(
Additionally, the keyword 'greedy' is supported,
which would use the default portfolio from
`AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_.
dataset_compression: Union[Mapping[str, Any], bool] = False
We compress datasets so that they fit into some predefined amount of memory.
**NOTE**

Default configuration when left as ``True``:
.. code-block:: python
{
"memory_allocation": 0.1,
"methods": ["precision"]
}
You can also pass your own configuration, using the same keys and choosing
from the available ``"methods"``.
The available options are described below:
**memory_allocation**
By default, we attempt to fit the dataset into ``0.1 * memory_limit``. This
float value can be set with ``"memory_allocation": 0.1``. We also allow for
specifying absolute memory in MB, e.g. 10MB is ``"memory_allocation": 10``.
The memory used by the dataset is checked after each reduction method is
performed. If the dataset fits into the allocated memory, any further methods
listed in ``"methods"`` will not be performed.

**methods**
We currently provide the following methods for reducing the dataset size.
These can be provided in a list and are performed in the order as given.
* ``"precision"`` - We reduce floating point precision as follows:
* ``np.float128 -> np.float64``
* ``np.float96 -> np.float64``
* ``np.float64 -> np.float32``
* pandas dataframes are reduced using the downcast option of `pd.to_numeric`
to the lowest possible precision.

Returns:
self

"""
self._dataset_compression: Optional[Mapping[str, Any]]

if isinstance(dataset_compression, bool):
if dataset_compression is True:
self._dataset_compression = default_dataset_compression_arg
else:
self._dataset_compression = None
else:
self._dataset_compression = dataset_compression

if self._dataset_compression is not None:
self._dataset_compression = validate_dataset_compression_arg(
self._dataset_compression, memory_limit=memory_limit)

self.dataset, self.InputValidator = self._get_dataset_input_validator(
X_train=X_train,
@@ -381,7 +431,8 @@
y_test=y_test,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
dataset_name=dataset_name)
dataset_name=dataset_name,
dataset_compression=self._dataset_compression)

return self._search(
dataset=self.dataset,
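For orientation, here is a minimal usage sketch of the new argument, assuming the usual TabularClassificationTask workflow from the AutoPyTorch examples; the data, metric and limits below are placeholders rather than part of this change.

# Illustrative only: a minimal call shape for the new dataset_compression argument.
import numpy as np

from autoPyTorch.api.tabular_classification import TabularClassificationTask

# Placeholder data; any array-like or DataFrame accepted by search() works.
X = np.random.rand(1000, 20)
y = np.random.randint(0, 2, size=1000)

api = TabularClassificationTask()
api.search(
    X_train=X,
    y_train=y,
    optimize_metric='accuracy',
    memory_limit=4096,
    # True would apply the documented defaults
    # ({"memory_allocation": 0.1, "methods": ["precision"]});
    # here we reserve an absolute 512 MB for the dataset instead.
    dataset_compression={"memory_allocation": 512, "methods": ["precision"]},
)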
60 changes: 56 additions & 4 deletions autoPyTorch/api/tabular_regression.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np

@@ -11,6 +11,10 @@
TASK_TYPES_TO_STRING
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.data.utils import (
default_dataset_compression_arg,
validate_dataset_compression_arg
)
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
HoldoutValTypes,
@@ -164,6 +168,7 @@ def _get_dataset_input_validator(
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
dataset_compression: Optional[Mapping[str, Any]] = None,
) -> Tuple[TabularDataset, TabularInputValidator]:
"""
Returns an object of `TabularDataset` and an object of
@@ -203,6 +208,7 @@ def _get_dataset_input_validator(
InputValidator = TabularInputValidator(
is_classification=False,
logger_port=self._logger_port,
dataset_compression=dataset_compression
)

# Fit an input validator to check the provided data
@@ -235,14 +241,15 @@ def search(
total_walltime_limit: int = 100,
func_eval_time_limit_secs: Optional[int] = None,
enable_traditional_pipeline: bool = True,
memory_limit: Optional[int] = 4096,
memory_limit: int = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
load_models: bool = True,
portfolio_selection: Optional[str] = None,
dataset_compression: Union[Mapping[str, Any], bool] = False,
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.
@@ -311,7 +318,7 @@ def search(
feature by turning this flag to False. All machine learning
algorithms that are fitted during search() are considered for
ensemble building.
memory_limit (Optional[int]: default=4096):
memory_limit (int: default=4096):
Memory limit in MB for the machine learning algorithm.
Autopytorch will stop fitting the machine learning algorithm
if it tries to allocate more than memory_limit MB. If None
@@ -369,19 +376,64 @@ def search(
Additionally, the keyword 'greedy' is supported,
which would use the default portfolio from
`AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_.
dataset_compression: Union[Mapping[str, Any], bool] = False
We compress datasets so that they fit into some predefined amount of memory.
**NOTE**

Default configuration when left as ``True``:
.. code-block:: python
{
"memory_allocation": 0.1,
"methods": ["precision"]
}
You can also pass your own configuration, using the same keys and choosing
from the available ``"methods"``.
The available options are described below:
**memory_allocation**
By default, we attempt to fit the dataset into ``0.1 * memory_limit``. This
float value can be set with ``"memory_allocation": 0.1``. We also allow for
specifying absolute memory in MB, e.g. 10MB is ``"memory_allocation": 10``.
The memory used by the dataset is checked after each reduction method is
performed. If the dataset fits into the allocated memory, any further methods
listed in ``"methods"`` will not be performed.

**methods**
We currently provide the following methods for reducing the dataset size.
These can be provided in a list and are performed in the order as given.
* ``"precision"`` - We reduce floating point precision as follows:
* ``np.float128 -> np.float64``
* ``np.float96 -> np.float64``
* ``np.float64 -> np.float32``
* pandas dataframes are reduced using the downcast option of `pd.to_numeric`
to the lowest possible precision.

Returns:
self

"""
self._dataset_compression: Optional[Mapping[str, Any]]

if isinstance(dataset_compression, bool):
if dataset_compression is True:
self._dataset_compression = default_dataset_compression_arg
else:
self._dataset_compression = None
else:
self._dataset_compression = dataset_compression

if self._dataset_compression is not None:
self._dataset_compression = validate_dataset_compression_arg(
self._dataset_compression, memory_limit=memory_limit)

self.dataset, self.InputValidator = self._get_dataset_input_validator(
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
dataset_name=dataset_name)
dataset_name=dataset_name,
dataset_compression=self._dataset_compression)

return self._search(
dataset=self.dataset,
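As a side note, the effect of the ``"precision"`` method described in both docstrings above can be illustrated with a small self-contained sketch. This is not the implementation from autoPyTorch.data.utils, only the dtype mapping the documentation describes (float128/float96 to float64, float64 to float32, and pandas downcasting via pd.to_numeric):

# Illustration only: the dtype reductions described by the "precision" method,
# not the actual reduce_dataset_size_if_too_large implementation.
import numpy as np
import pandas as pd


def reduce_precision_sketch(X):
    """Downcast a numpy array or pandas DataFrame along the documented mapping."""
    if isinstance(X, pd.DataFrame):
        # pandas frames: downcast each float column as far as pd.to_numeric allows.
        return X.apply(lambda col: pd.to_numeric(col, downcast="float")
                       if col.dtype.kind == "f" else col)

    # numpy arrays: float128/float96 -> float64, float64 -> float32.
    mapping = {np.dtype(np.float64): np.float32}
    for wide in ("float128", "float96"):  # not available on every platform
        if hasattr(np, wide):
            mapping[np.dtype(getattr(np, wide))] = np.float64
    return X.astype(mapping.get(X.dtype, X.dtype))


X = np.random.rand(1_000, 50)          # float64, roughly 400 KB
X_small = reduce_precision_sketch(X)   # float32, roughly 200 KB
print(X.nbytes, X_small.nbytes)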
27 changes: 26 additions & 1 deletion autoPyTorch/data/tabular_feature_validator.py
@@ -1,5 +1,6 @@
import functools
from typing import Dict, List, Optional, Tuple, cast
from logging import Logger
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast

import numpy as np

@@ -17,6 +18,8 @@
from sklearn.pipeline import make_pipeline

from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
from autoPyTorch.data.utils import DatasetDTypeContainerType, reduce_dataset_size_if_too_large
from autoPyTorch.utils.logging_ import PicklableClientLogger


def _create_column_transformer(
@@ -92,6 +95,15 @@ class TabularFeatureValidator(BaseFeatureValidator):
categorical_columns (List[int]):
List of indices of categorical columns
"""
def __init__(
self,
logger: Optional[Union[PicklableClientLogger, Logger]] = None,
dataset_compression: Optional[Mapping[str, Any]] = None,
) -> None:
self._dataset_compression = dataset_compression
self._precision: Optional[DatasetDTypeContainerType] = None
super().__init__(logger)

@staticmethod
def _comparator(cmp1: str, cmp2: str) -> int:
"""Order so that categorical columns come left and numerical columns come right
@@ -272,6 +284,19 @@ def transform(
"Please try to manually cast it to a supported "
"numerical or categorical values.")
raise e

if (
(
isinstance(X, np.ndarray) or scipy.sparse.issparse(X) or hasattr(X, 'iloc')
)
and self._dataset_compression is not None
):
if self._precision is not None:
X = X.astype(self._precision)
else:
X = reduce_dataset_size_if_too_large(X, **self._dataset_compression)
self._precision = dict(X.dtypes) if hasattr(X, 'iloc') else X.dtype

return X

def _check_data(
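Note why transform() above caches ``self._precision``: the method is called for the training split first and for test data later, and the reduction should only be computed once so that every split ends up with identical dtypes. A rough sketch of that caching pattern, reduced to numpy arrays for brevity (the real validator also stores a dict of per-column dtypes for DataFrames):

import numpy as np


class PrecisionCache:
    """Toy stand-in for the dtype caching in TabularFeatureValidator.transform."""

    def __init__(self):
        self._precision = None

    def transform(self, X):
        if self._precision is not None:
            # Later calls (e.g. the test split) are cast to the cached dtype.
            return X.astype(self._precision)
        # First call: reduce (here simply float64 -> float32 for illustration)
        # and remember the dtype we ended up with.
        if X.dtype == np.float64:
            X = X.astype(np.float32)
        self._precision = X.dtype
        return X


cache = PrecisionCache()
X_train = np.random.rand(100, 5)   # float64
X_test = np.random.rand(20, 5)     # float64
assert cache.transform(X_train).dtype == np.float32
assert cache.transform(X_test).dtype == np.float32   # matches the training split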
8 changes: 6 additions & 2 deletions autoPyTorch/data/tabular_validator.py
@@ -1,6 +1,6 @@
# -*- encoding: utf-8 -*-
import logging
from typing import Optional, Union
from typing import Any, Mapping, Optional, Union

from autoPyTorch.data.base_validator import BaseInputValidator
from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator
@@ -32,9 +32,11 @@ def __init__(
self,
is_classification: bool = False,
logger_port: Optional[int] = None,
dataset_compression: Optional[Mapping[str, Any]] = None,
) -> None:
self.is_classification = is_classification
self.logger_port = logger_port
self.dataset_compression = dataset_compression
if self.logger_port is not None:
self.logger: Union[logging.Logger, PicklableClientLogger] = get_named_client_logger(
name='Validation',
@@ -43,7 +45,9 @@
else:
self.logger = logging.getLogger('Validation')

self.feature_validator = TabularFeatureValidator(logger=self.logger)
self.feature_validator = TabularFeatureValidator(
dataset_compression=self.dataset_compression,
logger=self.logger)
self.target_validator = TabularTargetValidator(
is_classification=self.is_classification,
logger=self.logger
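Finally, the validator can also be constructed directly with a compression config, which is essentially what the API classes above now do internally. A sketch under that assumption, reusing the same helpers the search() methods call; the fit/transform data are placeholders:

import numpy as np

from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.data.utils import (
    default_dataset_compression_arg,
    validate_dataset_compression_arg,
)

# Mirror search(): validate/normalize the mapping first, then hand it over.
compression = validate_dataset_compression_arg(
    default_dataset_compression_arg, memory_limit=4096)

validator = TabularInputValidator(
    is_classification=True,
    dataset_compression=compression,
)

# Fitting and transforming then work as before, with the feature validator
# applying the reduction inside transform().
X = np.random.rand(100, 5)
y = np.random.randint(0, 2, size=100)
validator.fit(X_train=X, y_train=y)
X_reduced, y_reduced = validator.transform(X, y)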