diff --git a/hpobench/benchmarks/ml/tabular_benchmark.py b/hpobench/benchmarks/ml/tabular_benchmark.py index 72e5fb31..c5525bf5 100644 --- a/hpobench/benchmarks/ml/tabular_benchmark.py +++ b/hpobench/benchmarks/ml/tabular_benchmark.py @@ -163,7 +163,7 @@ def _objective( metric_str = ', '.join(list(metrics.keys())) assert metric in list(metrics.keys()), f"metric not found among: {metric_str}" score_key = f"{evaluation}_scores" - cost_key = f"{evaluation}_scores" + cost_key = f"{evaluation}_costs" key_path = dict() for name in self.configuration_space.get_hyperparameter_names(): diff --git a/hpobench/util/data_manager.py b/hpobench/util/data_manager.py index a2e33121..479244f9 100644 --- a/hpobench/util/data_manager.py +++ b/hpobench/util/data_manager.py @@ -41,6 +41,15 @@ import hpobench + +tabular_multi_fidelity_urls = dict( + xgb="https://figshare.com/ndownloader/files/30469920", + svm="https://figshare.com/ndownloader/files/30379359", + lr="https://figshare.com/ndownloader/files/30379038", + rf="https://figshare.com/ndownloader/files/30469089", + nn="https://figshare.com/ndownloader/files/30379005" +) + class DataManager(abc.ABC, metaclass=abc.ABCMeta): """ Base Class for loading and managing the data. @@ -929,21 +938,14 @@ def _load(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndar class TabularDataManager(DataManager): def __init__(self, model: str, task_id: [int, str], data_dir: [str, Path, None] = None): super(TabularDataManager, self).__init__() + + self.model = model + self.task_id = str(task_id) - url_dict = dict( - xgb="https://ndownloader.figshare.com/files/30469920", - svm="https://ndownloader.figshare.com/files/30379359", - lr="https://ndownloader.figshare.com/files/30379038", - rf="https://ndownloader.figshare.com/files/30469089", - nn="https://ndownloader.figshare.com/files/30379005" - ) - + url_dict = tabular_multi_fidelity_urls assert model in url_dict.keys(), \ f'Model has to be one of {list(url_dict.keys())} but was {model}' - self.model = model - self.task_id = str(task_id) - self.url_to_use = url_dict.get(model) if data_dir is None: