diff --git a/autosklearn/automl.py b/autosklearn/automl.py
index 93fde84330..ab70dc141f 100644
--- a/autosklearn/automl.py
+++ b/autosklearn/automl.py
@@ -652,6 +652,7 @@ def fit(
# space
self._backend.save_start_time(self._seed)
+ progress_bar.start()
self._stopwatch = StopWatch()
# Make sure that input is valid
@@ -970,7 +971,7 @@ def fit(
self._logger.exception(e)
raise e
finally:
- progress_bar.stop()
+ progress_bar.join()
self._fit_cleanup()
self.fitted = True
@@ -1920,15 +1921,17 @@ def cv_results_(self):
metric_dict[metric.name] = []
metric_mask[metric.name] = []
+ model_ids = []
mean_fit_time = []
params = []
status = []
budgets = []
- for run_key in self.runhistory_.data:
- run_value = self.runhistory_.data[run_key]
+ for run_key, run_value in self.runhistory_.data.items():
config_id = run_key.config_id
config = self.runhistory_.ids_config[config_id]
+ if run_value.additional_info and "num_run" in run_value.additional_info:
+ model_ids.append(run_value.additional_info["num_run"])
s = run_value.status
if s == StatusType.SUCCESS:
@@ -1989,6 +1992,8 @@ def cv_results_(self):
metric_dict[metric.name].append(metric_value)
metric_mask[metric.name].append(mask_value)
+ results["model_ids"] = model_ids
+
if len(self._metrics) == 1:
results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name])
rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"]
@@ -2164,14 +2169,11 @@ def show_models(self) -> dict[int, Any]:
warnings.warn("No ensemble found. Returning empty dictionary.")
return ensemble_dict
- def has_key(rv, key):
- return rv.additional_info and key in rv.additional_info
-
table_dict = {}
- for run_key, run_val in self.runhistory_.data.items():
- if has_key(run_val, "num_run"):
- model_id = run_val.additional_info["num_run"]
- table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost}
+ for run_key, run_value in self.runhistory_.data.items():
+ if run_value.additional_info and "num_run" in run_value.additional_info:
+ model_id = run_value.additional_info["num_run"]
+ table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost}
# Checking if the dictionary is empty
if not table_dict:
diff --git a/autosklearn/util/progress_bar.py b/autosklearn/util/progress_bar.py
index 7ccd3bc153..c1eb3139f8 100644
--- a/autosklearn/util/progress_bar.py
+++ b/autosklearn/util/progress_bar.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import Any
import datetime
@@ -10,22 +12,45 @@
class ProgressBar(Thread):
"""A Thread that displays a tqdm progress bar in the console.
- It is specialized to display information relevant to fitting to the training data
- with auto-sklearn.
+ Treat this class as an ordinary thread: to display a progress bar,
+ call start() on an instance of this class. To wait for the thread to
+ terminate, call join(), which maxes out the progress bar and
+ thereby terminates this thread immediately.
Parameters
----------
total : int
- The total amount that should be reached by the progress bar once it finishes
- update_interval : float
- Specifies how frequently the progress bar is updated (in seconds)
- disable : bool
- Turns on or off the progress bar. If True, this thread won't be started or
- initialized.
- kwargs : Any
+ The total amount that should be reached by the progress bar once it finishes.
+ update_interval : float, default=1.0
+ Specifies how frequently the progress bar is updated (in seconds).
+ disable : bool, default=False
+ Whether to disable the progress bar. If True, this thread is not
+ initialized and calling start() does nothing.
+ tqdm_kwargs : Any, optional
Keyword arguments that are passed into tqdm's constructor. Refer to:
- `tqdm `_. Note that postfix can not be
- specified in the kwargs since it is already passed into tqdm by this class.
+ `tqdm <https://tqdm.github.io/docs/tqdm/>`_ for a list of parameters that
+ tqdm accepts. Note that 'postfix' cannot be specified in tqdm_kwargs since it is
+ already passed into tqdm by this class.
+
+ Examples
+ --------
+
+ .. code:: python
+
+ progress_bar = ProgressBar(
+ total=10,
+ desc="Executing code that runs for 10 seconds",
+ colour="green",
+ )
+ # desc and colour are tqdm parameters passed via tqdm_kwargs
+ try:
+ progress_bar.start()
+ # some code that runs for 10 seconds
+ except SomeException:
+ # something went wrong
+ finally:
+ progress_bar.join()
+ # perform some cleanup
"""
def __init__(
@@ -33,7 +58,7 @@ def __init__(
total: int,
update_interval: float = 1.0,
disable: bool = False,
- **kwargs: Any,
+ **tqdm_kwargs: Any,
):
self.disable = disable
if not disable:
@@ -41,28 +66,27 @@ def __init__(
self.total = total
self.update_interval = update_interval
self.terminated: bool = False
- self.kwargs = kwargs
- # start this thread
- self.start()
+ self.tqdm_kwargs = tqdm_kwargs
- def run(self) -> None:
- """Display a tqdm progress bar in the console.
+ def start(self) -> None:
+ """Start a new thread that calls the run() method."""
+ if not self.disable:
+ super().start()
- Additionally, it shows useful information related to the task. This method
- overrides the run method of Thread.
- """
+ def run(self) -> None:
+ """Display a tqdm progress bar in the console."""
if not self.disable:
for _ in trange(
self.total,
postfix=f"The total time budget for this task is "
f"{datetime.timedelta(seconds=self.total)}",
- **self.kwargs,
+ **self.tqdm_kwargs,
):
if not self.terminated:
time.sleep(self.update_interval)
- def stop(self) -> None:
- """Terminates the thread."""
+ def join(self, timeout: float | None = None) -> None:
+ """Max out the progress bar, thereby terminating this thread."""
if not self.disable:
self.terminated = True
- super().join()
+ super().join(timeout)