diff --git a/autosklearn/automl.py b/autosklearn/automl.py index 93fde84330..ab70dc141f 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -652,6 +652,7 @@ def fit( # space self._backend.save_start_time(self._seed) + progress_bar.start() self._stopwatch = StopWatch() # Make sure that input is valid @@ -970,7 +971,7 @@ def fit( self._logger.exception(e) raise e finally: - progress_bar.stop() + progress_bar.join() self._fit_cleanup() self.fitted = True @@ -1920,15 +1921,17 @@ def cv_results_(self): metric_dict[metric.name] = [] metric_mask[metric.name] = [] + model_ids = [] mean_fit_time = [] params = [] status = [] budgets = [] - for run_key in self.runhistory_.data: - run_value = self.runhistory_.data[run_key] + for run_key, run_value in self.runhistory_.data.items(): config_id = run_key.config_id config = self.runhistory_.ids_config[config_id] + if run_value.additional_info and "num_run" in run_value.additional_info: + model_ids.append(run_value.additional_info["num_run"]) s = run_value.status if s == StatusType.SUCCESS: @@ -1989,6 +1992,8 @@ def cv_results_(self): metric_dict[metric.name].append(metric_value) metric_mask[metric.name].append(mask_value) + results["model_ids"] = model_ids + if len(self._metrics) == 1: results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name]) rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"] @@ -2164,14 +2169,11 @@ def show_models(self) -> dict[int, Any]: warnings.warn("No ensemble found. 
Returning empty dictionary.") return ensemble_dict - def has_key(rv, key): - return rv.additional_info and key in rv.additional_info - table_dict = {} - for run_key, run_val in self.runhistory_.data.items(): - if has_key(run_val, "num_run"): - model_id = run_val.additional_info["num_run"] - table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost} + for run_key, run_value in self.runhistory_.data.items(): + if run_value.additional_info and "num_run" in run_value.additional_info: + model_id = run_value.additional_info["num_run"] + table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost} # Checking if the dictionary is empty if not table_dict: diff --git a/autosklearn/util/progress_bar.py b/autosklearn/util/progress_bar.py index 7ccd3bc153..c1eb3139f8 100644 --- a/autosklearn/util/progress_bar.py +++ b/autosklearn/util/progress_bar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any import datetime @@ -10,22 +12,45 @@ class ProgressBar(Thread): """A Thread that displays a tqdm progress bar in the console. - It is specialized to display information relevant to fitting to the training data - with auto-sklearn. + Treat this class as an ordinary thread. So to display a progress bar, + call start() on an instance of this class. To wait for the thread to + terminate, call join(), which maxes out the progress bar and + therefore terminates this thread immediately. Parameters ---------- total : int - The total amount that should be reached by the progress bar once it finishes - update_interval : float - Specifies how frequently the progress bar is updated (in seconds) - disable : bool - Turns on or off the progress bar. If True, this thread won't be started or - initialized. - kwargs : Any + The total amount that should be reached by the progress bar once it finishes. + update_interval : float, default=1.0 + Specifies how frequently the progress bar is updated (in seconds). 
+ disable : bool, default=False + Turns on or off the progress bar. If True, this thread does not get + initialized and won't be started if start() is called. + tqdm_kwargs : Any, optional Keyword arguments that are passed into tqdm's constructor. Refer to: - `tqdm `_. Note that postfix can not be - specified in the kwargs since it is already passed into tqdm by this class. + `tqdm `_ for a list of parameters that + tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is + already passed into tqdm by this class. + + Examples + -------- + + .. code:: python + + progress_bar = ProgressBar( + total=10, + desc="Executing code that runs for 10 seconds", + colour="green", + ) + # colour is a tqdm parameter passed as a tqdm_kwargs + try: + progress_bar.start() + # some code that runs for 10 seconds + except SomeException: + # something went wrong + finally: + progress_bar.join() + # perform some cleanup """ def __init__( @@ -33,7 +58,7 @@ def __init__( total: int, update_interval: float = 1.0, disable: bool = False, - **kwargs: Any, + **tqdm_kwargs: Any, ): self.disable = disable if not disable: @@ -41,28 +66,27 @@ def __init__( self.total = total self.update_interval = update_interval self.terminated: bool = False - self.kwargs = kwargs - # start this thread - self.start() + self.tqdm_kwargs = tqdm_kwargs - def run(self) -> None: - """Display a tqdm progress bar in the console. + def start(self) -> None: + """Start a new thread that calls the run() method.""" + if not self.disable: + super().start() - Additionally, it shows useful information related to the task. This method - overrides the run method of Thread. 
- """ + def run(self) -> None: + """Display a tqdm progress bar in the console.""" if not self.disable: for _ in trange( self.total, postfix=f"The total time budget for this task is " f"{datetime.timedelta(seconds=self.total)}", - **self.kwargs, + **self.tqdm_kwargs, ): if not self.terminated: time.sleep(self.update_interval) - def stop(self) -> None: - """Terminates the thread.""" + def join(self, timeout: float | None = None) -> None: + """Maxes out the progress bar, thereby terminating this thread.""" if not self.disable: self.terminated = True - super().join() + super().join(timeout)