diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py
index 33ff0d5135cd..c34aa9960d90 100644
--- a/python/ray/tune/schedulers/median_stopping_rule.py
+++ b/python/ray/tune/schedulers/median_stopping_rule.py
@@ -60,15 +60,15 @@ def on_trial_result(self, trial_runner, trial, result):
         value by step `t` is strictly worse than the median of the running
         averages of all completed trials' objectives reported up to step `t`.
         """
-
-        if trial in self._stopped_trials:
+        trial_id = trial.trial_id
+        if trial_id in self._stopped_trials:
             assert not self._hard_stop
             return TrialScheduler.CONTINUE  # fall back to FIFO
 
         time = result[self._time_attr]
-        self._results[trial].append(result)
+        self._results[trial_id].append(result)
         median_result = self._get_median_result(time)
-        best_result = self._best_result(trial)
+        best_result = self._best_result(trial_id)
         if self._verbose:
             logger.info("Trial {} best res={} vs median res={} at t={}".format(
                 trial, best_result, median_result, time))
@@ -76,7 +76,7 @@ def on_trial_result(self, trial_runner, trial, result):
             if self._verbose:
                 logger.info("MedianStoppingRule: "
                             "early stopping {}".format(trial))
-            self._stopped_trials.add(trial)
+            self._stopped_trials.add(trial_id)
             if self._hard_stop:
                 return TrialScheduler.STOP
             else:
@@ -85,13 +85,13 @@ def on_trial_result(self, trial_runner, trial, result):
         return TrialScheduler.CONTINUE
 
     def on_trial_complete(self, trial_runner, trial, result):
-        self._results[trial].append(result)
-        self._completed_trials.add(trial)
+        self._results[trial.trial_id].append(result)
+        self._completed_trials.add(trial.trial_id)
 
     def on_trial_remove(self, trial_runner, trial):
         """Marks trial as completed if it is paused and has previously ran."""
-        if trial.status is Trial.PAUSED and trial in self._results:
-            self._completed_trials.add(trial)
+        if trial.status is Trial.PAUSED and trial.trial_id in self._results:
+            self._completed_trials.add(trial.trial_id)
 
     def debug_string(self):
         return "Using MedianStoppingRule: num_stopped={}.".format(
@@ -99,15 +99,15 @@ def debug_string(self):
 
     def _get_median_result(self, time):
         scores = []
-        for trial in self._completed_trials:
-            scores.append(self._running_result(trial, time))
+        for trial_id in self._completed_trials:
+            scores.append(self._running_result(trial_id, time))
         if len(scores) >= self._min_samples_required:
             return np.median(scores)
         else:
             return float('-inf')
 
-    def _running_result(self, trial, t_max=float('inf')):
-        results = self._results[trial]
+    def _running_result(self, trial_id, t_max=float('inf')):
+        results = self._results[trial_id]
         # TODO(ekl) we could do interpolation to be more precise, but for now
         # assume len(results) is large and the time diffs are roughly equal
         return np.mean([
@@ -115,6 +115,6 @@ def _running_result(self, trial, t_max=float('inf')):
             if r[self._time_attr] <= t_max
         ])
 
-    def _best_result(self, trial):
-        results = self._results[trial]
+    def _best_result(self, trial_id):
+        results = self._results[trial_id]
         return max(r[self._reward_attr] for r in results)
diff --git a/python/ray/tune/schedulers/trial_scheduler.py b/python/ray/tune/schedulers/trial_scheduler.py
index 15fa3cb4cdc6..d816f013c19e 100644
--- a/python/ray/tune/schedulers/trial_scheduler.py
+++ b/python/ray/tune/schedulers/trial_scheduler.py
@@ -2,6 +2,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from ray.tune.trial import Trial
 
 
@@ -64,6 +66,12 @@ def debug_string(self):
 
         raise NotImplementedError
 
+    def __getstate__(self):
+        raise NotImplementedError
+
+    def __setstate__(self, state):
+        raise NotImplementedError
+
 
 class FIFOScheduler(TrialScheduler):
     """Simple scheduler that just runs trials in submission order."""
@@ -96,3 +104,9 @@ def choose_trial_to_run(self, trial_runner):
 
     def debug_string(self):
         return "Using FIFO scheduling algorithm."
+
+    def __getstate__(self):
+        return copy.deepcopy(self.__dict__)
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)