Skip to content
Merged
4 changes: 3 additions & 1 deletion autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,6 @@ def fit(
-------
self
"""
progress_bar = ProgressBar(total=self._time_for_task, disable=self.disable_progress_bar)
if (X_test is not None) ^ (y_test is not None):
raise ValueError("Must provide both X_test and y_test together")

Expand Down Expand Up @@ -630,6 +629,9 @@ def fit(
# By default try to use the TCP logging port or get a new port
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

progress_bar = ProgressBar(
total=self._time_for_task, disable=self.disable_progress_bar
)
# Once we start the logging server, it starts in a new process
# If an error occurs then we want to make sure that we exit cleanly
# and shut it down, else it might hang
Expand Down
6 changes: 5 additions & 1 deletion autosklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def __init__(
Whether autosklearn should process string features. By default the
textpreprocessing is enabled.

disable_progress_bar: bool = False
Whether to disable the progress bar that is displayed in the console
while fitting to the training data.

Attributes
----------
cv_results_ : dict of numpy (masked) ndarrays
Expand Down Expand Up @@ -527,7 +531,7 @@ def build_automl(self):
get_trials_callback=self.get_trials_callback,
dataset_compression=self.dataset_compression,
allow_string_features=self.allow_string_features,
disable_progress_bar=self.disable_progress_bar
disable_progress_bar=self.disable_progress_bar,
)

return automl
Expand Down
6 changes: 6 additions & 0 deletions autosklearn/experimental/askl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
load_models: bool = True,
dataset_compression: Union[bool, Mapping[str, Any]] = True,
allow_string_features: bool = True,
disable_progress_bar: bool = False,
):

"""
Expand Down Expand Up @@ -284,6 +285,10 @@ def __init__(
load_models : bool, optional (True)
Whether to load the models after fitting Auto-sklearn.

disable_progress_bar: bool = False
Whether to disable the progress bar that is displayed in the console
while fitting to the training data.

Attributes
----------

Expand Down Expand Up @@ -337,6 +342,7 @@ def __init__(
scoring_functions=scoring_functions,
load_models=load_models,
allow_string_features=allow_string_features,
disable_progress_bar=disable_progress_bar,
)

def train_selectors(self, selected_metric=None):
Expand Down
61 changes: 44 additions & 17 deletions autosklearn/util/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,65 @@
import datetime
import time

from threading import Thread
from tqdm import trange

from tqdm import trange # type: ignore


class ProgressBar(Thread):
"""A Thread that displays a tqdm progress bar in the console."""
"""
A Thread that displays a tqdm progress bar in the console.
def __init__(self, total: float, update_interval: float = 1.0, disable: bool = False):
"""
Parameters
----------
total: the total amount that the progress bar should reach
update_interval: reduce this to update the progress bar more frequently
disable: flag that turns on or off the progress bar. If false, then no thread is started or created.
"""
It is specialized to display information relevant to fitting to the training data
with auto-sklearn.
Parameters
----------
total : int
The total amount that should be reached by the progress bar once it finishes
update_interval : float
Specifies how frequently the progress bar is updated (in seconds)
disable : bool
Turns on or off the progress bar. If True, this thread won't be started or
initialized.
"""

def __init__(
self,
total: int,
update_interval: float = 1.0,
disable: bool = False,
):
self.disable = disable
if not disable:
super().__init__(name="_progressbar_")
self.total = total
self.update_interval = update_interval
self.terminated: bool = False
# start this thread
self.start()

def run(self):
def run(self) -> None:
"""
Overrides the run method of Thread. It displays a tqdm progress bar in the
console with useful descriptions about the task.
"""
if not self.disable:
for _ in trange(self.total, colour="green"):
for _ in trange(
self.total,
colour="green",
desc="Fitting to the training data",
postfix=f"The total time budget for this task is"
f" {datetime.timedelta(seconds=self.total)}",
):
if not self.terminated:
time.sleep(self.update_interval)
else:
pass # max out the bar
print("Finishing up the task...")

def stop(self):
def stop(self) -> None:
"""
Terminates the thread.
"""
if not self.disable:
self.terminated = True
super().join()

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ pyyaml
pandas>=1.0
liac-arff
threadpoolctl
tqdm

ConfigSpace>=0.4.21,<0.5
pynisher>=0.6.3,<0.7
pyrfr>=0.8.1,<0.9
smac>=1.2,<1.3
smac>=1.2,<1.3