diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 87c14b520562..c971a8783feb 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -17,15 +17,8 @@ logger = logging.getLogger(__name__) -try: - import tensorflow as tf - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) -except ImportError: - tf = None - use_tf150_api = True - logger.warning("Couldn't import TensorFlow - " - "disabling TensorBoard logging.") +tf = None +use_tf150_api = True class Logger(object): @@ -190,6 +183,15 @@ def to_tf_values(result, path): class _TFLogger(Logger): def _init(self): + try: + global tf, use_tf150_api + import tensorflow + tf = tensorflow + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) + except ImportError: + logger.warning("Couldn't import TensorFlow - " + "disabling TensorBoard logging.") self._file_writer = tf.summary.FileWriter(self.logdir) def on_result(self, result): diff --git a/python/ray/tune/scripts.py b/python/ray/tune/scripts.py index 72db1453af19..c33cad9068b8 100644 --- a/python/ray/tune/scripts.py +++ b/python/ray/tune/scripts.py @@ -8,10 +8,9 @@ import os from datetime import datetime - import pandas as pd - from ray.tune.trial import Trial +from tabulate import tabulate def _flatten_dict(dt): @@ -34,13 +33,8 @@ def cli(): pass -DEFAULT_EXPERIMENT_INFO_KEYS = ( - "trial_name", - "trial_id", - "status", - "num_failures", - "logdir" -) +DEFAULT_EXPERIMENT_INFO_KEYS = ("trial_name", "trial_id", "status", + "num_failures", "logdir") DEFAULT_PROJECT_INFO_KEYS = ( "name", @@ -52,7 +46,8 @@ def cli(): ) -def _list_trials(experiment_path, info_keys=DEFAULT_EXPERIMENT_INFO_KEYS): +def _list_trials(experiment_path, sort, + info_keys=DEFAULT_EXPERIMENT_INFO_KEYS): experiment_path = os.path.expanduser(experiment_path) globs = glob.glob(os.path.join(experiment_path, "experiment_state*.json")) filename = max(list(globs)) @@ -60,30 +55,38 @@ def _list_trials(experiment_path, info_keys=DEFAULT_EXPERIMENT_INFO_KEYS): with open(filename) as f: experiment_state = json.load(f) - checkpoints = pd.DataFrame.from_records(experiment_state['checkpoints']) + checkpoints_df = pd.DataFrame( + experiment_state["checkpoints"])[list(info_keys)] + if "logdir" in checkpoints_df.columns: + checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace( + experiment_path, '') + if sort: + checkpoints_df = checkpoints_df.sort_values(by=sort) + print(tabulate(checkpoints_df, headers="keys", tablefmt="psql")) + # TODO(hartikainen): The logdir is often too verbose to be viewed in a # table. - checkpoints['logdir'] = checkpoints['logdir'].str.replace( - experiment_path, '') - - print(checkpoints[list(info_keys)].to_string()) + # checkpoints = pd.DataFrame.from_records(experiment_state['checkpoints']) + # checkpoints['logdir'] = checkpoints['logdir'].str.replace( + # experiment_path, '') + # print(checkpoints[list(info_keys)].to_string()) @cli.command() @click.argument("experiment_path", required=True, type=str) -def list_trials(experiment_path): - _list_trials(experiment_path) +@click.option( + '--sort', default=None, type=str, help='Select which column to sort on.') +def list_trials(experiment_path, sort): + _list_trials(experiment_path, sort) -def _list_experiments(project_path, info_keys=DEFAULT_PROJECT_INFO_KEYS): +def _list_experiments(project_path, sort, info_keys=DEFAULT_PROJECT_INFO_KEYS): base, experiment_paths, _ = list(os.walk(project_path))[0] # clean this experiment_data_collection = [] for experiment_path in experiment_paths: - experiment_state_path = glob.glob(os.path.join( - base, - experiment_path, - "experiment_state*.json")) + experiment_state_path = glob.glob( + os.path.join(base, experiment_path, "experiment_state*.json")) if not experiment_state_path: # TODO(hartikainen): Print some warning? @@ -94,29 +97,34 @@ def _list_experiments(project_path, info_keys=DEFAULT_PROJECT_INFO_KEYS): checkpoints = pd.DataFrame(experiment_state["checkpoints"]) runner_data = experiment_state["runner_data"] - timestamp = experiment_state["timestamp"] + timestamp = experiment_state.get("timestamp") experiment_data = { "name": experiment_path, - "start_time": runner_data["_start_time"], - "timestamp": datetime.fromtimestamp(timestamp), + "start_time": runner_data.get("_start_time"), + "timestamp": datetime.fromtimestamp(timestamp) + if timestamp else None, "total_trials": checkpoints.shape[0], "running_trials": (checkpoints["status"] == Trial.RUNNING).sum(), "terminated_trials": ( checkpoints["status"] == Trial.TERMINATED).sum(), "error_trials": (checkpoints["status"] == Trial.ERROR).sum(), - } + } experiment_data_collection.append(experiment_data) - info_dataframe = pd.DataFrame(experiment_data_collection) - print(info_dataframe[list(info_keys)].to_string()) + info_df = pd.DataFrame(experiment_data_collection) + if sort: + info_df = info_df.sort_values(by=sort) + print(tabulate(info_df, headers="keys", tablefmt="psql")) @cli.command() @click.argument("project_path", required=True, type=str) -def list_experiments(project_path): - _list_experiments(project_path) +@click.option( + '--sort', default=None, type=str, help='Select which column to sort on.') +def list_experiments(project_path, sort): + _list_experiments(project_path, sort) cli.add_command(list_trials, name="ls") diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index 98602ad52e7f..d67514d8cd7f 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -4,13 +4,16 @@ import copy -try: - import bayes_opt as byo -except Exception: - byo = None - from ray.tune.suggest.suggestion import SuggestionAlgorithm +byo = None + + +def _import_bayesopt(): + global byo + import bayes_opt + byo = bayes_opt + class BayesOptSearch(SuggestionAlgorithm): """A wrapper around BayesOpt to provide trial suggestions. diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 2c32562505f9..a68eece2b4f9 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -6,17 +6,19 @@ import copy import logging -try: - hyperopt_logger = logging.getLogger("hyperopt") - hyperopt_logger.setLevel(logging.WARNING) - import hyperopt as hpo - from hyperopt.fmin import generate_trials_to_calculate -except Exception: - hpo = None - from ray.tune.error import TuneError from ray.tune.suggest.suggestion import SuggestionAlgorithm +hpo = None + + +def _import_hyperopt(): + global hpo + hyperopt_logger = logging.getLogger("hyperopt") + hyperopt_logger.setLevel(logging.WARNING) + import hyperopt + hpo = hyperopt + class HyperOptSearch(SuggestionAlgorithm): """A wrapper around HyperOpt to provide trial suggestions. @@ -73,6 +75,7 @@ def __init__(self, reward_attr="episode_reward_mean", points_to_evaluate=None, **kwargs): + _import_hyperopt() assert hpo is not None, "HyperOpt must be installed!" assert type(max_concurrent) is int and max_concurrent > 0 self._max_concurrent = max_concurrent @@ -84,7 +87,7 @@ def __init__(self, self._points_to_evaluate = 0 else: assert type(points_to_evaluate) == list - self._hpopt_trials = generate_trials_to_calculate( + self._hpopt_trials = hpo.fmin.generate_trials_to_calculate( points_to_evaluate) self._hpopt_trials.refresh() self._points_to_evaluate = len(points_to_evaluate) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index f97b4711eac5..0a97bea4c284 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -547,8 +547,12 @@ def __getstate__(self): """ state = self.__dict__.copy() for k in [ - "_trials", "_stop_queue", "_server", "_search_alg", - "_scheduler_alg", "trial_executor", + "_trials", + "_stop_queue", + "_server", + "_search_alg", + "_scheduler_alg", + "trial_executor", ]: del state[k] state["_start_time"] = self._start_time.timestamp() @@ -561,7 +565,8 @@ def __setstate__(self, state): session = state.pop("_session") self.__dict__.setdefault("_session", session) start_time = state.pop("_start_time") - self.__dict__.setdefault("_start_time", datetime.fromtimestamp(start_time)) + self.__dict__.setdefault("_start_time", + datetime.fromtimestamp(start_time)) self.__dict__.update(state) if launch_web_server: diff --git a/python/setup.py b/python/setup.py index 95ba3cbf0bb3..7b8fc1118661 100644 --- a/python/setup.py +++ b/python/setup.py @@ -169,8 +169,7 @@ def find_version(*filepath): entry_points={ "console_scripts": [ "ray=ray.scripts.scripts:main", - "rllib=ray.rllib.scripts:cli [rllib]", - "tune=ray.tune.scripts:cli" + "rllib=ray.rllib.scripts:cli [rllib]", "tune=ray.tune.scripts:cli" ] }, include_package_data=True,