diff --git a/autosklearn/automl.py b/autosklearn/automl.py
index 93fde84330..1d37cb2321 100644
--- a/autosklearn/automl.py
+++ b/autosklearn/automl.py
@@ -652,6 +652,7 @@ def fit(
# space
self._backend.save_start_time(self._seed)
+ progress_bar.start()
self._stopwatch = StopWatch()
# Make sure that input is valid
@@ -970,7 +971,7 @@ def fit(
self._logger.exception(e)
raise e
finally:
- progress_bar.stop()
+ progress_bar.join()
self._fit_cleanup()
self.fitted = True
diff --git a/autosklearn/util/progress_bar.py b/autosklearn/util/progress_bar.py
index 7ccd3bc153..c1eb3139f8 100644
--- a/autosklearn/util/progress_bar.py
+++ b/autosklearn/util/progress_bar.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import Any
import datetime
@@ -10,22 +12,45 @@
class ProgressBar(Thread):
"""A Thread that displays a tqdm progress bar in the console.
- It is specialized to display information relevant to fitting to the training data
- with auto-sklearn.
+ Treat this class as an ordinary thread. So to display a progress bar,
+ call start() on an instance of this class. To wait for the thread to
+ terminate call join(), which will max out the progress bar,
+ therefore terminate this thread immediately.
Parameters
----------
total : int
- The total amount that should be reached by the progress bar once it finishes
- update_interval : float
- Specifies how frequently the progress bar is updated (in seconds)
- disable : bool
- Turns on or off the progress bar. If True, this thread won't be started or
- initialized.
- kwargs : Any
+ The total amount that should be reached by the progress bar once it finishes.
+ update_interval : float, default=1.0
+ Specifies how frequently the progress bar is updated (in seconds).
+ disable : bool, default=False
+ Turns on or off the progress bar. If True, this thread does not get
+ initialized and won't be started if start() is called.
+ tqdm_kwargs : Any, optional
Keyword arguments that are passed into tqdm's constructor. Refer to:
- `tqdm `_. Note that postfix can not be
- specified in the kwargs since it is already passed into tqdm by this class.
+ `tqdm `_ for a list of parameters that
+ tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is
+ already passed into tqdm by this class.
+
+ Examples
+ --------
+
+ .. code:: python
+
+ progress_bar = ProgressBar(
+ total=10,
+ desc="Executing code that runs for 10 seconds",
+ colour="green",
+ )
+ # colour is a tqdm parameter passed as a tqdm_kwargs
+ try:
+ progress_bar.start()
+ # some code that runs for 10 seconds
+ except SomeException:
+ # something went wrong
+ finally:
+ progress_bar.join()
+ # perform some cleanup
"""
def __init__(
@@ -33,7 +58,7 @@ def __init__(
total: int,
update_interval: float = 1.0,
disable: bool = False,
- **kwargs: Any,
+ **tqdm_kwargs: Any,
):
self.disable = disable
if not disable:
@@ -41,28 +66,27 @@ def __init__(
self.total = total
self.update_interval = update_interval
self.terminated: bool = False
- self.kwargs = kwargs
- # start this thread
- self.start()
+ self.tqdm_kwargs = tqdm_kwargs
- def run(self) -> None:
- """Display a tqdm progress bar in the console.
+ def start(self) -> None:
+ """Start a new thread that calls the run() method."""
+ if not self.disable:
+ super().start()
- Additionally, it shows useful information related to the task. This method
- overrides the run method of Thread.
- """
+ def run(self) -> None:
+ """Display a tqdm progress bar in the console."""
if not self.disable:
for _ in trange(
self.total,
postfix=f"The total time budget for this task is "
f"{datetime.timedelta(seconds=self.total)}",
- **self.kwargs,
+ **self.tqdm_kwargs,
):
if not self.terminated:
time.sleep(self.update_interval)
- def stop(self) -> None:
- """Terminates the thread."""
+ def join(self, timeout: float | None = None) -> None:
+ """Maxes out the progress bar and thereby terminating this thread."""
if not self.disable:
self.terminated = True
- super().join()
+ super().join(timeout)