diff --git a/autosklearn/automl.py b/autosklearn/automl.py index 93fde84330..1d37cb2321 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -652,6 +652,7 @@ def fit( # space self._backend.save_start_time(self._seed) + progress_bar.start() self._stopwatch = StopWatch() # Make sure that input is valid @@ -970,7 +971,7 @@ def fit( self._logger.exception(e) raise e finally: - progress_bar.stop() + progress_bar.join() self._fit_cleanup() self.fitted = True diff --git a/autosklearn/util/progress_bar.py b/autosklearn/util/progress_bar.py index 7ccd3bc153..c1eb3139f8 100644 --- a/autosklearn/util/progress_bar.py +++ b/autosklearn/util/progress_bar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any import datetime @@ -10,22 +12,45 @@ class ProgressBar(Thread): """A Thread that displays a tqdm progress bar in the console. - It is specialized to display information relevant to fitting to the training data - with auto-sklearn. + Treat this class as an ordinary thread. So to display a progress bar, + call start() on an instance of this class. To wait for the thread to + terminate call join(), which will max out the progress bar, + therefore terminate this thread immediately. Parameters ---------- total : int - The total amount that should be reached by the progress bar once it finishes - update_interval : float - Specifies how frequently the progress bar is updated (in seconds) - disable : bool - Turns on or off the progress bar. If True, this thread won't be started or - initialized. - kwargs : Any + The total amount that should be reached by the progress bar once it finishes. + update_interval : float, default=1.0 + Specifies how frequently the progress bar is updated (in seconds). + disable : bool, default=False + Turns on or off the progress bar. If True, this thread does not get + initialized and won't be started if start() is called. + tqdm_kwargs : Any, optional Keyword arguments that are passed into tqdm's constructor. Refer to: - `tqdm `_. Note that postfix can not be - specified in the kwargs since it is already passed into tqdm by this class. + `tqdm `_ for a list of parameters that + tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is + already passed into tqdm by this class. + + Examples + -------- + + .. code:: python + + progress_bar = ProgressBar( + total=10, + desc="Executing code that runs for 10 seconds", + colour="green", + ) + # colour is a tqdm parameter passed as a tqdm_kwargs + try: + progress_bar.start() + # some code that runs for 10 seconds + except SomeException: + # something went wrong + finally: + progress_bar.join() + # perform some cleanup """ def __init__( @@ -33,7 +58,7 @@ def __init__( total: int, update_interval: float = 1.0, disable: bool = False, - **kwargs: Any, + **tqdm_kwargs: Any, ): self.disable = disable if not disable: @@ -41,28 +66,27 @@ def __init__( self.total = total self.update_interval = update_interval self.terminated: bool = False - self.kwargs = kwargs - # start this thread - self.start() + self.tqdm_kwargs = tqdm_kwargs - def run(self) -> None: - """Display a tqdm progress bar in the console. + def start(self) -> None: + """Start a new thread that calls the run() method.""" + if not self.disable: + super().start() - Additionally, it shows useful information related to the task. This method - overrides the run method of Thread. - """ + def run(self) -> None: + """Display a tqdm progress bar in the console.""" if not self.disable: for _ in trange( self.total, postfix=f"The total time budget for this task is " f"{datetime.timedelta(seconds=self.total)}", - **self.kwargs, + **self.tqdm_kwargs, ): if not self.terminated: time.sleep(self.update_interval) - def stop(self) -> None: - """Terminates the thread.""" + def join(self, timeout: float | None = None) -> None: + """Maxes out the progress bar and thereby terminating this thread.""" if not self.disable: self.terminated = True - super().join() + super().join(timeout)