Skip to content

Commit

Permalink
Refactor usage of Display class into Oracle (#959)
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin authored Oct 10, 2023
1 parent 5c439ec commit 9b648f7
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 124 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
## Bug fixes
* When running in parallel, the client oracle used to wait forever when the
chief oracle was not responding. Now, it is fixed.
* When running in parallel, the client would call the chief after calling
`oracle.end_trial()`, when the chief had already ended. Now, it is fixed.
* When running in parallel, the chief used to start to block in
`tuner.__init__()`. However, it makes more sense to block when calling
`tuner.search()`. Now, it is fixed.

# Release v1.4.4

Expand Down
3 changes: 1 addition & 2 deletions keras_tuner/distribute/oracle_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def build_model(hp):
max_trials=10,
directory=tmp_path,
)
tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Only worker makes it to this point, server runs until thread stops.
assert dist_utils.has_chief_oracle()
Expand All @@ -118,8 +119,6 @@ def build_model(hp):
tuner.oracle, keras_tuner.distribute.oracle_client.OracleClient
)

tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Suppress warnings about optimizer state not being restored by
# tf.keras.

Expand Down
35 changes: 18 additions & 17 deletions keras_tuner/engine/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,13 @@ def __init__(
self._populate_initial_space()

# Run in distributed mode.
if dist_utils.is_chief_oracle():
# Blocks forever.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_chief

oracle_chief.start_server(self.oracle)
elif dist_utils.has_chief_oracle():
if dist_utils.has_chief_oracle() and not dist_utils.is_chief_oracle():
# Proxies requests to the chief oracle.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_client

self.oracle = oracle_client.OracleClient(self.oracle)

# In parallel tuning, everything below in __init__() is for workers
# only.
# Logs etc
self._display = tuner_utils.Display(oracle=self.oracle)

def _activate_all_conditions(self):
# Lists of stacks of conditions used during `explore_space()`.
scopes_never_active = []
Expand Down Expand Up @@ -211,10 +200,24 @@ def search(self, *fit_args, **fit_kwargs):
**fit_kwargs: Keyword arguments that should be passed to
`run_trial`, for example the training and validation data.
"""
verbose = "auto"
if "verbose" in fit_kwargs:
verbose = fit_kwargs.get("verbose")
self._display.verbose = verbose

# Only set verbosity on chief or when not running in parallel.
if (
not dist_utils.has_chief_oracle()
or dist_utils.is_chief_oracle()
):
self.oracle.verbose = verbose

if dist_utils.is_chief_oracle():
# Blocks until all the trials are finished.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_chief

oracle_chief.start_server(self.oracle)
return

self.on_search_begin()
while True:
self.pre_create_trial()
Expand Down Expand Up @@ -324,7 +327,7 @@ def on_trial_begin(self, trial):
Args:
trial: A `Trial` instance.
"""
self._display.on_trial_begin(self.oracle.get_trial(trial.trial_id))
pass

def on_trial_end(self, trial):
"""Called at the end of a trial.
Expand All @@ -333,8 +336,6 @@ def on_trial_end(self, trial):
trial: A `Trial` instance.
"""
self.oracle.end_trial(trial)
# Display needs the updated trial scored by the Oracle.
self._display.on_trial_end(self.oracle.get_trial(trial.trial_id))
self.save()

def on_search_begin(self):
Expand Down
152 changes: 152 additions & 0 deletions keras_tuner/engine/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import random
import threading
import warnings
from datetime import datetime

import numpy as np

Expand Down Expand Up @@ -113,6 +114,138 @@ def wrapped_func(*args, **kwargs):
return wrapped_func


# TODO: Add more extensive display.
class Display(stateful.Stateful):
    """Prints per-trial progress for an oracle and tracks trial timing.

    Args:
        oracle: The oracle whose trials are reported. Its `trials`,
            `objective` and `get_best_trials()` are read when printing.
        verbose: Verbosity level. Nothing is printed or recorded when it
            is below 1. The string `"auto"` is treated as 1.
    """

    def __init__(self, oracle, verbose=1):
        self.verbose = verbose
        self.oracle = oracle
        # Width of the first two columns of the hyperparameter table.
        self.col_width = 18

        # Start time for the overall search
        self.search_start = None

        # Start time of the trials
        # {trial_id: start_time}
        self.trial_start = {}
        # Trial number of the trials, starting from #1.
        # {trial_id: trial_number}
        self.trial_number = {}

    def _is_silent(self):
        """Return True when nothing should be printed.

        `"auto"` is resolved to 1 here as well, so that a `Display` that
        was handed `"auto"` directly does not fail on `"auto" < 1`.
        """
        level = 1 if self.verbose == "auto" else self.verbose
        return level < 1

    def get_state(self):
        """Return the timing bookkeeping with datetimes as ISO strings."""
        return {
            "search_start": self.search_start.isoformat()
            if self.search_start is not None
            else self.search_start,
            "trial_start": {
                key: value.isoformat()
                for key, value in self.trial_start.items()
            },
            "trial_number": self.trial_number,
        }

    def set_state(self, state):
        """Restore the bookkeeping produced by `get_state()`.

        Args:
            state: A dict with the keys `"search_start"`, `"trial_start"`
                and `"trial_number"`.
        """
        self.search_start = (
            datetime.fromisoformat(state["search_start"])
            if state["search_start"] is not None
            else state["search_start"]
        )
        self.trial_start = {
            key: datetime.fromisoformat(value)
            for key, value in state["trial_start"].items()
        }

        self.trial_number = state["trial_number"]

    def on_trial_begin(self, trial):
        """Record the start time and number of `trial` and print its header.

        Does nothing when the display is silent.

        Args:
            trial: The trial that is starting; `trial_id` and
                `hyperparameters` are read.
        """
        if self._is_silent():
            return

        start_time = datetime.now()
        self.trial_start[trial.trial_id] = start_time
        # The first reported trial also marks the start of the search.
        if self.search_start is None:
            self.search_start = start_time
        current_number = len(self.oracle.trials)
        self.trial_number[trial.trial_id] = current_number

        print()
        print(f"Search: Running Trial #{current_number}")
        print()
        self.show_hyperparameter_table(trial)
        print()

    def on_trial_end(self, trial):
        """Print the summary of a finished trial.

        Does nothing when the display is silent.

        Args:
            trial: The trial that ended; `trial_id` and `score` are read.
        """
        if self._is_silent():
            return

        # NOTE(review): presumably clears the previous notebook/terminal
        # output; confirm against `utils.try_clear`.
        utils.try_clear()

        # The start time and number are only recorded by `on_trial_begin()`
        # while printing is enabled. If verbosity was raised after the trial
        # began they are missing, so fall back instead of raising KeyError.
        trial_started = self.trial_start.get(trial.trial_id)
        if trial_started is None:
            time_taken_str = "unknown"
        else:
            time_taken_str = self.format_duration(
                datetime.now() - trial_started
            )
        trial_number = self.trial_number.get(trial.trial_id, trial.trial_id)
        print(f"Trial {trial_number} Complete [{time_taken_str}]")

        if trial.score is not None:
            print(f"{self.oracle.objective.name}: {trial.score}")

        print()
        best_trials = self.oracle.get_best_trials()
        best_score = best_trials[0].score if len(best_trials) > 0 else None
        print(f"Best {self.oracle.objective.name} So Far: {best_score}")

        # `search_start` is unset under the same circumstances as above.
        if self.search_start is None:
            time_elapsed_str = "unknown"
        else:
            time_elapsed_str = self.format_duration(
                datetime.now() - self.search_start
            )
        print(f"Total elapsed time: {time_elapsed_str}")

    def show_hyperparameter_table(self, trial):
        """Print the trial's hyperparameter values next to the best so far.

        Args:
            trial: The trial whose `hyperparameters.values` are listed.
        """
        # Two fixed-width columns followed by the hyperparameter name.
        template = "{{0:{0}}}|{{1:{0}}}|{{2}}".format(self.col_width)
        best_trials = self.oracle.get_best_trials()
        best_trial = best_trials[0] if len(best_trials) > 0 else None
        if trial.hyperparameters.values:
            print(
                template.format("Value", "Best Value So Far", "Hyperparameter")
            )
            for hp, value in trial.hyperparameters.values.items():
                # "?" marks that there is no best trial to compare with yet.
                best_value = (
                    best_trial.hyperparameters.values.get(hp)
                    if best_trial
                    else "?"
                )
                print(
                    template.format(
                        self.format_value(value),
                        self.format_value(best_value),
                        hp,
                    )
                )
        else:
            print("default configuration")

    def format_value(self, val):
        """Format a hyperparameter value so it fits in a table column.

        Numbers (but not booleans) use 5 significant digits. Anything else
        is converted with `str()` and truncated with a trailing "..." when
        longer than `col_width`.
        """
        if isinstance(val, (int, float)) and not isinstance(val, bool):
            return f"{val:.5g}"
        val_str = str(val)
        if len(val_str) > self.col_width:
            val_str = f"{val_str[:self.col_width - 3]}..."
        return val_str

    def format_duration(self, d):
        """Format a `timedelta` as `[Dd ]HHh MMm SSs`.

        Args:
            d: The duration; it is rounded to whole seconds.

        Returns:
            A string such as `"00h 01m 05s"`, prefixed with the number of
            days (for example `"7d 00h 00m 00s"`) when that is positive.
        """
        total_seconds = round(d.total_seconds())
        days, remainder = divmod(total_seconds, 86400)
        hours, remainder = divmod(remainder, 3600)
        minutes, seconds = divmod(remainder, 60)

        if days > 0:
            return (
                f"{days:d}d {hours:02d}h {minutes:02d}m {seconds:02d}s"
            )
        return f"{hours:02d}h {minutes:02d}m {seconds:02d}s"


@keras_tuner_export("keras_tuner.Oracle")
class Oracle(stateful.Stateful):
"""Implements a hyperparameter optimization algorithm.
Expand Down Expand Up @@ -234,6 +367,19 @@ def __init__(
self.max_retries_per_trial = max_retries_per_trial
self.max_consecutive_failed_trials = max_consecutive_failed_trials

# Print the logs to screen
self._display = Display(oracle=self)

@property
def verbose(self):
    """Verbosity level currently held by this oracle's display."""
    display = self._display
    return display.verbose

@verbose.setter
def verbose(self, value):
    """Store the verbosity on the display, mapping `"auto"` to 1."""
    self._display.verbose = 1 if value == "auto" else value

def _populate_space(self, trial_id):
warnings.warn(
"The `_populate_space` method is deprecated, "
Expand Down Expand Up @@ -309,6 +455,7 @@ def create_trial(self, tuner_id):
trial.status = trial_module.TrialStatus.RUNNING
self.ongoing_trials[tuner_id] = trial
self.save()
self._display.on_trial_begin(trial)
return trial

# Make the trial_id the current number of trial, pre-padded with 0s
Expand Down Expand Up @@ -342,6 +489,7 @@ def create_trial(self, tuner_id):
self.start_order.append(trial_id)
self._save_trial(trial)
self.save()
self._display.on_trial_begin(trial)

return trial

Expand Down Expand Up @@ -431,6 +579,8 @@ def end_trial(self, trial):
self._save_trial(trial)
self.save()

self._display.on_trial_end(trial)

# Pop the ongoing trial at last, which would notify the chief server to
# stop when ongoing_trials is empty.
for tuner_id, ongoing_trial in self.ongoing_trials.items():
Expand Down Expand Up @@ -547,6 +697,7 @@ def get_state(self):
"seed_state": self._seed_state,
"tried_so_far": list(self._tried_so_far),
"id_to_hash": self._id_to_hash,
"display": self._display.get_state(),
}

def set_state(self, state):
Expand All @@ -568,6 +719,7 @@ def set_state(self, state):
self._tried_so_far = set(state["tried_so_far"])
self._id_to_hash = collections.defaultdict(lambda: None)
self._id_to_hash.update(state["id_to_hash"])
self._display.set_state(state["display"])

def _set_project_dir(self, directory, project_name):
"""Sets the project directory and reloads the Oracle."""
Expand Down
10 changes: 10 additions & 0 deletions keras_tuner/engine/oracle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import threading
import time

Expand All @@ -21,6 +22,7 @@
import keras_tuner
from keras_tuner.engine import oracle as oracle_module
from keras_tuner.engine import trial as trial_module
from keras_tuner.tuners import gridsearch


class OracleStub(oracle_module.Oracle):
Expand Down Expand Up @@ -440,3 +442,11 @@ def test_overwrite_false_resume(tmp_path):
assert (
oracle.get_trial(trial_id).status == trial_module.TrialStatus.COMPLETED
)


def test_display_format_duration_large_d():
    """A week-long duration shows a day field; `"auto"` verbosity is 1."""
    grid_oracle = gridsearch.GridSearchOracle()
    week_start = datetime.datetime(2020, 5, 10)
    week_end = datetime.datetime(2020, 5, 17)
    one_week = week_end - week_start
    grid_oracle.verbose = "auto"
    display = oracle_module.Display(grid_oracle)
    formatted = display.format_duration(one_week)
    assert formatted == "7d 00h 00m 00s"
    assert grid_oracle.verbose == 1
5 changes: 4 additions & 1 deletion keras_tuner/engine/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,10 @@ def test_correct_display_trial_number(tmp_path):
new_tuner.search(
TRAIN_INPUTS, TRAIN_TARGETS, validation_data=(VAL_INPUTS, VAL_TARGETS)
)
assert len(new_tuner.oracle.trials) == new_tuner._display.trial_number
new_tuner.oracle._display.trial_number.items()
assert len(new_tuner.oracle.trials) == max(
new_tuner.oracle._display.trial_number.values()
)


def test_error_on_unknown_objective_direction(tmp_path):
Expand Down
Loading

0 comments on commit 9b648f7

Please sign in to comment.