Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 162 additions & 77 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,89 +551,142 @@ def _do_dummy_prediction(self, num_run: int) -> None:
% (str(status), str(additional_info))
)

def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) -> int:
def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_time_limit: int
) -> int:
"""
Fits traditional machine learning algorithms to the provided dataset, while
complying with time resource allocation.

This method currently only supports classification.

Args:
num_run: (int)
An identifier to indicate the current machine learning algorithm
being processed
time_left: (int)
Hard limit on how many machine learning algorithms can be fit. Depending on how
fast a traditional machine learning algorithm trains, it will allow multiple
models to be fitted.
func_eval_time_limit: (int)
Maximum training time each algorithm is allowed to take, during training

Returns:
num_run: (int)
The incremented identifier index. This depends on how many machine learning
models where fitted.
"""

# Mypy Checkings -- Traditional prediction is only called for search
# where the following objects are created
assert self._metric is not None
assert self._logger is not None
assert self._dask_client is not None

self._logger.info("Starting to create dummy predictions.")

memory_limit = self._memory_limit
if memory_limit is not None:
memory_limit = int(math.ceil(memory_limit))
available_classifiers = get_available_classifiers()
dask_futures = list()
time_for_traditional_classifier_sec = int(time_for_traditional / len(available_classifiers))
dask_futures = []

total_number_classifiers = len(available_classifiers) + num_run
for n_r, classifier in enumerate(available_classifiers, start=num_run):

# Only launch a task if there is time
start_time = time.time()
scenario_mock = unittest.mock.Mock()
scenario_mock.wallclock_limit = time_for_traditional_classifier_sec
# This stats object is a hack - maybe the SMAC stats object should
# already be generated here!
stats = Stats(scenario_mock)
stats.start_timing()
ta = ExecuteTaFuncWithQueue(
backend=self._backend,
seed=self.seed,
metric=self._metric,
logger_port=self._logger_port,
cost_for_crash=get_cost_of_crash(self._metric),
abort_on_first_run_crash=False,
initial_num_run=num_run,
stats=stats,
memory_limit=memory_limit,
disable_file_output=True if len(self._disable_file_output) > 0 else False,
all_supported_metrics=self._all_supported_metrics
)
dask_futures.append((classifier, self._dask_client.submit(ta.run, config=classifier,
cutoff=time_for_traditional_classifier_sec)))
if time_left >= func_eval_time_limit:
self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={func_eval_time_limit}")
scenario_mock = unittest.mock.Mock()
scenario_mock.wallclock_limit = time_left
# This stats object is a hack - maybe the SMAC stats object should
# already be generated here!
stats = Stats(scenario_mock)
stats.start_timing()
ta = ExecuteTaFuncWithQueue(
backend=self._backend,
seed=self.seed,
metric=self._metric,
logger_port=self._logger_port,
cost_for_crash=get_cost_of_crash(self._metric),
abort_on_first_run_crash=False,
initial_num_run=n_r,
stats=stats,
memory_limit=memory_limit,
disable_file_output=True if len(self._disable_file_output) > 0 else False,
all_supported_metrics=self._all_supported_metrics
)
dask_futures.append([
classifier,
self._dask_client.submit(
ta.run, config=classifier,
cutoff=func_eval_time_limit,
)
])

# Increment the launched job index
num_run = n_r

if len(dask_futures) >= self.n_jobs:

# How many workers to wait before starting fitting the next iteration
workers_to_wait = 1
if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit:
# If on the last iteration, flush out all tasks
workers_to_wait = len(dask_futures)

while workers_to_wait >= 1:
workers_to_wait -= 1
# We launch dask jobs only when there are resources available.
# This allow us to control time allocation properly, and early terminate
# the traditional machine learning pipeline
cls, future = dask_futures.pop(0)
status, cost, runtime, additional_info = future.result()
if status == StatusType.SUCCESS:
self._logger.info(
f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}")
else:
if additional_info.get('exitcode') == -6:
self._logger.error(
"Traditional prediction for %s failed with run state %s. "
"The error suggests that the provided memory limits were too tight. Please "
"increase the 'ml_memory_limit' and try again. If this does not solve your "
"problem, please open an issue and paste the additional output. "
"Additional output: %s.",
cls, str(status), str(additional_info),
)
else:
self._logger.error(
"Traditional prediction for %s failed with run state %s and additional output: %s.",
cls, str(status), str(additional_info),
)

# In the case of a serial execution, calling submit halts the run for a resource
# dynamically adjust time in this case
time_for_traditional_classifier_sec -= int(time.time() - start_time)
num_run = n_r
time_left -= int(time.time() - start_time)

# Exit if no more time is available for a new classifier
if time_left < func_eval_time_limit:
self._logger.warning("Not enough time to fit all traditional machine learning models."
"Please consider increasing the run time to further improve performance.")
break

for (classifier, future) in dask_futures:
status, cost, runtime, additional_info = future.result()
if status == StatusType.SUCCESS:
self._logger.info("Finished creating predictions for {}".format(classifier))
else:
if additional_info.get('exitcode') == -6:
self._logger.error(
"Traditional prediction for %s failed with run state %s. "
"The error suggests that the provided memory limits were too tight. Please "
"increase the 'ml_memory_limit' and try again. If this does not solve your "
"problem, please open an issue and paste the additional output. "
"Additional output: %s.",
classifier, str(status), str(additional_info),
)
else:
# TODO: add check for timeout, and provide feedback to user to consider increasing the time limit
self._logger.error(
"Traditional prediction for %s failed with run state %s and additional output: %s.",
classifier, str(status), str(additional_info),
)
return num_run

def _search(
self,
optimize_metric: str,
dataset: BaseDataset,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
func_eval_time_limit: int = 60,
traditional_per_total_budget: float = 0.1,
memory_limit: Optional[int] = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: List = [],
load_models: bool = True,
self,
optimize_metric: str,
dataset: BaseDataset,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
func_eval_time_limit: Optional[int] = None,
enable_traditional_pipeline: bool = True,
memory_limit: Optional[int] = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: List = [],
load_models: bool = True,
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.
Expand All @@ -660,16 +713,20 @@ def _search(
in seconds for the search of appropriate models.
By increasing this value, autopytorch has a higher
chance of finding better models.
func_eval_time_limit (int), (default=60): Time limit
func_eval_time_limit (int), (default=None): Time limit
for a single call to the machine learning model.
Model fitting will be terminated if the machine
learning algorithm runs over the time limit. Set
this value high enough so that typical machine
learning algorithms can be fit on the training
data.
traditional_per_total_budget (float), (default=0.1):
Percent of total walltime to be allocated for
running traditional classifiers.
enable_traditional_pipeline (bool), (default=True):
We fit traditional machine learning algorithms
(LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM)
prior building PyTorch Neural Networks. You can disable this
feature by turning this flag to False. All machine learning
algorithms that are fitted during search() are considered for
ensemble building.
memory_limit (Optional[int]), (default=4096): Memory
limit in MB for the machine learning algorithm. autopytorch
will stop fitting the machine learning algorithm if it tries
Expand Down Expand Up @@ -755,6 +812,28 @@ def _search(
else:
self._is_dask_client_internally_created = False

# Handle time resource allocation
elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
if func_eval_time_limit is None or func_eval_time_limit > time_left_for_modelfit:
self._logger.warning(
'Time limit for a single run is higher than total time '
'limit. Capping the limit for a single run to the total '
'time given to SMAC (%f)' % time_left_for_modelfit
)
func_eval_time_limit = time_left_for_modelfit

# Make sure that at least 2 models are created for the ensemble process
num_models = time_left_for_modelfit // func_eval_time_limit
if num_models < 2:
func_eval_time_limit = time_left_for_modelfit // 2
self._logger.warning(
"Capping the func_eval_time_limit to {} to have "
"time for a least 2 models to ensemble.".format(
func_eval_time_limit
)
)

# ============> Run dummy predictions
num_run = 1
dummy_task_name = 'runDummy'
Expand All @@ -764,16 +843,22 @@ def _search(

# ============> Run traditional ml

traditional_task_name = 'runTraditional'
self._stopwatch.start_task(traditional_task_name)
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
time_for_traditional = int(traditional_per_total_budget * max(0, (self._time_for_task - elapsed_time)))
if time_for_traditional <= 0:
if traditional_per_total_budget > 0:
raise ValueError("Not enough time allocated to run traditional algorithms")
elif traditional_per_total_budget != 0:
num_run = self._do_traditional_prediction(num_run=num_run + 1, time_for_traditional=time_for_traditional)
self._stopwatch.stop_task(traditional_task_name)
if enable_traditional_pipeline:
if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
else:
traditional_task_name = 'runTraditional'
self._stopwatch.start_task(traditional_task_name)
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
# We want time for at least 1 Neural network in SMAC
time_for_traditional = int(
self._time_for_task - elapsed_time - func_eval_time_limit
)
num_run = self._do_traditional_prediction(
num_run=num_run + 1, func_eval_time_limit=func_eval_time_limit,
time_left=time_for_traditional,
)
self._stopwatch.stop_task(traditional_task_name)

# ============> Starting ensemble
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
Expand Down
18 changes: 11 additions & 7 deletions autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def search(
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
func_eval_time_limit: int = 60,
traditional_per_total_budget: float = 0.1,
func_eval_time_limit: Optional[int] = None,
enable_traditional_pipeline: bool = True,
memory_limit: Optional[int] = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
Expand Down Expand Up @@ -156,16 +156,20 @@ def search(
in seconds for the search of appropriate models.
By increasing this value, autopytorch has a higher
chance of finding better models.
func_eval_time_limit (int), (default=60): Time limit
func_eval_time_limit (int), (default=None): Time limit
for a single call to the machine learning model.
Model fitting will be terminated if the machine
learning algorithm runs over the time limit. Set
this value high enough so that typical machine
learning algorithms can be fit on the training
data.
traditional_per_total_budget (float), (default=0.1):
Percent of total walltime to be allocated for
running traditional classifiers.
enable_traditional_pipeline (bool), (default=True):
We fit traditional machine learning algorithms
(LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM)
prior building PyTorch Neural Networks. You can disable this
feature by turning this flag to False. All machine learning
algorithms that are fitted during search() are considered for
ensemble building.
memory_limit (Optional[int]), (default=4096): Memory
limit in MB for the machine learning algorithm. autopytorch
will stop fitting the machine learning algorithm if it tries
Expand Down Expand Up @@ -229,7 +233,7 @@ def search(
budget=budget,
total_walltime_limit=total_walltime_limit,
func_eval_time_limit=func_eval_time_limit,
traditional_per_total_budget=traditional_per_total_budget,
enable_traditional_pipeline=enable_traditional_pipeline,
memory_limit=memory_limit,
smac_scenario_args=smac_scenario_args,
get_smac_object_callback=get_smac_object_callback,
Expand Down
Loading