Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e147851
TEST: commit a dummy file
nabenabe0928 Feb 2, 2021
2e9dbc6
REFACTORING (base_dataset.py): The outputs of all the modifications h…
nabenabe0928 Feb 3, 2021
9ae9dcf
REFACTORING (resampling strategy.py): Implanted the functions into CR…
nabenabe0928 Feb 3, 2021
f3c5749
REFACTORING (resampling_strategy.py): Implant is_stratified into each…
nabenabe0928 Feb 3, 2021
741e93c
REFACTORING (resampling_strategy, base_dataset, time_series_dataset):…
nabenabe0928 Feb 3, 2021
70d99f9
fixed small issue in get_holdout_validators in HoldOutFuncs
nabenabe0928 Feb 3, 2021
f0cf3a7
(REFACTORING in utils.common.py): Added raise error that happens when…
nabenabe0928 Feb 4, 2021
3a4074e
[REFACTORING (resampling_strategy.py)]: added raise error if the chil…
nabenabe0928 Feb 5, 2021
f7984d9
[REFACTORING]: to random_state (I will commit this change one by one …
nabenabe0928 Feb 5, 2021
d7011c0
As a record. Before the change becomes too much
nabenabe0928 Feb 5, 2021
088d29f
[REFACTORING (base_dataset.py)]: deleted DEFAULT_RESAMPLING_PARAMETERS
nabenabe0928 Feb 5, 2021
cf9d463
[BUG FIXED] removed bugs to let test_tabular_dataset.py work, but not…
nabenabe0928 Feb 5, 2021
988adb9
[REFACTORING] United HoldOutVal -> HoldOut
nabenabe0928 Feb 5, 2021
1eebd33
[Formatting]: formatted resampling_strategy.py and base_dataset.py
nabenabe0928 Feb 9, 2021
16310d7
[REFACTORING]: unite to splitting_type and splitting_params to be abl…
nabenabe0928 Feb 10, 2021
2b47eba
[DEBUG] to be able to run test_evaluation.py
nabenabe0928 Feb 10, 2021
f9ff334
[DEBUG] to be able to run test_train_evaluator.py and test_abstract_e…
nabenabe0928 Feb 10, 2021
6cdb589
[REFACTORING]: united var name: resampling_strategy -> splitting_type…
nabenabe0928 Feb 16, 2021
056fcbb
[REFACTORING]: adapted to the string specification of splitting type
nabenabe0928 Feb 16, 2021
7fa00b4
[REFACTORING: utils/pipeline.py] Separate the erro processing and enh…
nabenabe0928 Feb 16, 2021
37417fa
[REFACTORING: utils.pipeline.py] Created task type class and reduced …
nabenabe0928 Feb 17, 2021
c5733e0
[REFACTORING: constants.py and utils.pipeline.py] Made it possible to…
nabenabe0928 Feb 17, 2021
ce1f638
[New Feature]: added BaseDict to utils.common.py
nabenabe0928 Feb 17, 2021
5a6ea8b
[MEMO]: ignore this commit (added TODO lines)
nabenabe0928 Feb 17, 2021
ebd4bf2
[REFACTORING]: Added Base class for each task type class
nabenabe0928 Feb 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ dask-worker-space/
# Test output
tmp/
.tmp_evaluation

# Private file
grep.py
memo.txt
30 changes: 15 additions & 15 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
STRING_TO_TASK_TYPES,
)
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.datasets.train_val_split import CrossValTypes, HoldOutTypes
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
Expand Down Expand Up @@ -175,8 +175,8 @@ def __init__(
# By default try to use the TCP logging port or get a new port
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

# Store the resampling strategy from the dataset, to load models as needed
self.resampling_strategy = None # type: Optional[Union[CrossValTypes, HoldoutValTypes]]
# Store the splitting type from the dataset, to load models as needed
self.splitting_type = None # type: Optional[Union[CrossValTypes, HoldOutTypes]]

self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]

Expand Down Expand Up @@ -398,21 +398,21 @@ def _close_dask_client(self) -> None:
self._is_dask_client_internally_created = False
del self._is_dask_client_internally_created

def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]]
def _load_models(self, splitting_type: Optional[Union[CrossValTypes, HoldOutTypes]]
) -> bool:

"""
Loads the models saved in the temporary directory
during the smac run and the final ensemble created
Args:
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data
splitting_type (Union[CrossValTypes, HoldOutTypes]): splitting type used to split the data
and to validate the performance of a candidate pipeline

Returns:
None
"""
if resampling_strategy is None:
raise ValueError("Resampling strategy is needed to determine what models to load")
if splitting_type is None:
raise ValueError("Splitting type is needed to determine what models to load")
self.ensemble_ = self._backend.load_ensemble(self.seed)

# If no ensemble is loaded, try to get the best performing model
Expand All @@ -422,10 +422,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, Holdou
if self.ensemble_:
identifiers = self.ensemble_.get_selected_model_identifiers()
self.models_ = self._backend.load_models_by_identifiers(identifiers)
if isinstance(resampling_strategy, CrossValTypes):
if isinstance(splitting_type, CrossValTypes):
self.cv_models_ = self._backend.load_cv_models_by_identifiers(identifiers)

if isinstance(resampling_strategy, CrossValTypes):
if isinstance(splitting_type, CrossValTypes):
if len(self.cv_models_) == 0:
raise ValueError('No models fitted!')

Expand Down Expand Up @@ -705,7 +705,7 @@ def search(
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._stopwatch.start_task(experiment_task_name)
self.dataset_name = dataset.dataset_name
self.resampling_strategy = dataset.resampling_strategy
self.splitting_type = dataset.splitting_type
self._logger = self._get_logger(self.dataset_name)
self._all_supported_metrics = all_supported_metrics
self._disable_file_output = disable_file_output
Expand Down Expand Up @@ -869,7 +869,7 @@ def search(

if load_models:
self._logger.info("Loading models...")
self._load_models(dataset.resampling_strategy)
self._load_models(dataset.splitting_type)
self._logger.info("Finished loading models...")

# Clean up the logger
Expand Down Expand Up @@ -927,7 +927,7 @@ def refit(
})
X.update({**self.pipeline_options, **budget_config})
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
self._load_models(dataset.resampling_strategy)
self._load_models(dataset.splitting_type)

# Refit is not applicable when ensemble_size is set to zero.
if self.ensemble_ is None:
Expand Down Expand Up @@ -1025,17 +1025,17 @@ def predict(
if self._logger is None:
self._logger = self._get_logger("Predict-Logger")

if self.ensemble_ is None and not self._load_models(self.resampling_strategy):
if self.ensemble_ is None and not self._load_models(self.splitting_type):
raise ValueError("No ensemble found. Either fit has not yet "
"been called or no ensemble was fitted")

# Mypy assert
assert self.ensemble_ is not None, "Load models should error out if no ensemble"
self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_)

if isinstance(self.resampling_strategy, HoldoutValTypes):
if isinstance(self.splitting_type, HoldOutTypes):
models = self.models_
elif isinstance(self.resampling_strategy, CrossValTypes):
elif isinstance(self.splitting_type, CrossValTypes):
models = self.cv_models_

all_predictions = joblib.Parallel(n_jobs=n_jobs)(
Expand Down
86 changes: 86 additions & 0 deletions autoPyTorch/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,87 @@
"""Constant numbers for this package

TODO:
* Make this file nicer
* transfer to proper locations e.g. create task directory?
"""

from enum import Enum
from autoPyTorch.pipeline.image_classification import ImageClassificationPipeline
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
import abc
from typing import Union


# Type alias covering the three pipeline types imported above; used as the
# return annotation of the ``pipeline`` methods of the task-type classes.
SupportedPipelines = Union[ImageClassificationPipeline,
TabularClassificationPipeline,
TabularRegressionPipeline]


class BaseTaskTypes(metaclass=abc.ABCMeta):
    """Interface shared by the task-type enumerations of this module.

    A concrete task type reports whether it is supported and exposes its
    task name, its dataset type and the pipeline associated with it.
    """

    @abc.abstractmethod
    def is_supported(self) -> bool:
        """Return whether this task type is supported."""
        raise NotImplementedError

    @abc.abstractmethod
    def task_name(self) -> str:
        """Return the name of the task this type belongs to."""
        raise NotImplementedError

    @abc.abstractmethod
    def dataset_type(self) -> str:
        """Return the name of the dataset type."""
        raise NotImplementedError

    @abc.abstractmethod
    def pipeline(self) -> SupportedPipelines:
        """Return the pipeline associated with this task type."""
        raise NotImplementedError


class _TaskTypeEnumMeta(abc.ABCMeta, type(Enum)):
    """Metaclass combining ``abc.ABCMeta`` with the metaclass of ``Enum``.

    ``BaseTaskTypes`` is created by ``abc.ABCMeta`` and ``Enum`` by its own
    metaclass. A class deriving from both needs a metaclass that derives
    from both; otherwise class creation fails with a metaclass conflict.
    """


# Mix-in classes must be listed before ``Enum`` in the bases of an
# enumeration, hence ``(BaseTaskTypes, Enum)`` and not the reverse.
class RegressionTypes(BaseTaskTypes, Enum, metaclass=_TaskTypeEnumMeta):
    """Regression task types; the member value is the pipeline or ``None``.

    A member whose value is ``None`` is reported as unsupported.
    """

    # NOTE(review): ``Enum`` turns members with equal values into aliases.
    # Since ``image`` and ``time_series`` are both ``None``, ``time_series``
    # is an alias of ``image`` (``RegressionTypes.time_series.name`` is
    # 'image'). Distinct values would change ``.value`` for callers, so this
    # is flagged here rather than silently changed.
    tabular = TabularRegressionPipeline
    image = None
    time_series = None

    def is_supported(self) -> bool:
        """Return True when a pipeline is registered for this member."""
        return self.value is not None

    def task_name(self) -> str:
        """Return the task name, always ``'regressor'``."""
        return 'regressor'

    def dataset_type(self) -> str:
        """Return the member name, which doubles as the dataset type."""
        return self.name

    def pipeline(self) -> SupportedPipelines:
        """Return the member value, i.e. the registered pipeline.

        Raises:
            ValueError: if no pipeline is registered for this member.
        """
        if not self.is_supported():
            raise ValueError(f"{self.name} is not supported pipeline.")

        return self.value


class ClassificationTypes(BaseTaskTypes, Enum, metaclass=_TaskTypeEnumMeta):
    """Classification task types; the member value is the pipeline or ``None``.

    A member whose value is ``None`` is reported as unsupported.
    """

    tabular = TabularClassificationPipeline
    image = ImageClassificationPipeline
    time_series = None

    def is_supported(self) -> bool:
        """Return True when a pipeline is registered for this member."""
        return self.value is not None

    def task_name(self) -> str:
        """Return the task name, always ``'classifier'``."""
        return 'classifier'

    def dataset_type(self) -> str:
        """Return the member name, which doubles as the dataset type."""
        return self.name

    def pipeline(self) -> SupportedPipelines:
        """Return the member value, i.e. the registered pipeline.

        Raises:
            ValueError: if no pipeline is registered for this member.
        """
        if not self.is_supported():
            raise ValueError(f"{self.name} is not supported pipeline.")

        return self.value


# The task-type enumerations defined above, grouped in a tuple.
SupportedTaskTypes = (RegressionTypes, ClassificationTypes)


"""TODO: remove these variables
TABULAR_CLASSIFICATION = 1
IMAGE_CLASSIFICATION = 2
TABULAR_REGRESSION = 3
Expand Down Expand Up @@ -27,6 +111,8 @@
'image_regression': IMAGE_REGRESSION,
'time_series_classification': TIMESERIES_CLASSIFICATION,
'time_series_regression': TIMESERIES_REGRESSION}
"""


# Output types have been defined as in scikit-learn type_of_target
# (https://scikit-learn.org/stable/modules/generated/sklearn.utils.multiclass.type_of_target.html)
Expand Down
Loading