Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor usage of Display class into Oracle #959

Merged
merged 7 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
## Bug fixes
* When running in parallel, the client oracle used to wait forever when the
chief oracle was not responding. Now, it is fixed.
* When running in parallel, the client would call the chief after calling
`oracle.end_trial()`, when the chief had already ended. Now, it is fixed.
* When running in parallel, the chief used to start to block in
`tuner.__init__()`. However, it makes more sense to block when calling
`tuner.search()`. Now, it is fixed.

# Release v1.4.4

Expand Down
3 changes: 1 addition & 2 deletions keras_tuner/distribute/oracle_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def build_model(hp):
max_trials=10,
directory=tmp_path,
)
tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Only worker makes it to this point, server runs until thread stops.
assert dist_utils.has_chief_oracle()
Expand All @@ -118,8 +119,6 @@ def build_model(hp):
tuner.oracle, keras_tuner.distribute.oracle_client.OracleClient
)

tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Suppress warnings about optimizer state not being restored by
# tf.keras.

Expand Down
35 changes: 18 additions & 17 deletions keras_tuner/engine/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,13 @@ def __init__(
self._populate_initial_space()

# Run in distributed mode.
if dist_utils.is_chief_oracle():
# Blocks forever.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_chief

oracle_chief.start_server(self.oracle)
elif dist_utils.has_chief_oracle():
if dist_utils.has_chief_oracle() and not dist_utils.is_chief_oracle():
# Proxies requests to the chief oracle.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_client

self.oracle = oracle_client.OracleClient(self.oracle)

# In parallel tuning, everything below in __init__() is for workers
# only.
# Logs etc
self._display = tuner_utils.Display(oracle=self.oracle)

def _activate_all_conditions(self):
# Lists of stacks of conditions used during `explore_space()`.
scopes_never_active = []
Expand Down Expand Up @@ -211,10 +200,24 @@ def search(self, *fit_args, **fit_kwargs):
**fit_kwargs: Keyword arguments that should be passed to
`run_trial`, for example the training and validation data.
"""
verbose = "auto"
if "verbose" in fit_kwargs:
verbose = fit_kwargs.get("verbose")
self._display.verbose = verbose

# Only set verbosity on chief or when not running in parallel.
if (
not dist_utils.has_chief_oracle()
or dist_utils.is_chief_oracle()
):
self.oracle.verbose = verbose

if dist_utils.is_chief_oracle():
# Blocks until all the trials are finished.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_chief

oracle_chief.start_server(self.oracle)
return

self.on_search_begin()
while True:
self.pre_create_trial()
Expand Down Expand Up @@ -324,7 +327,7 @@ def on_trial_begin(self, trial):
Args:
trial: A `Trial` instance.
"""
self._display.on_trial_begin(self.oracle.get_trial(trial.trial_id))
pass

def on_trial_end(self, trial):
"""Called at the end of a trial.
Expand All @@ -333,8 +336,6 @@ def on_trial_end(self, trial):
trial: A `Trial` instance.
"""
self.oracle.end_trial(trial)
# Display needs the updated trial scored by the Oracle.
self._display.on_trial_end(self.oracle.get_trial(trial.trial_id))
self.save()

def on_search_begin(self):
Expand Down
152 changes: 152 additions & 0 deletions keras_tuner/engine/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import random
import threading
import warnings
from datetime import datetime

import numpy as np

Expand Down Expand Up @@ -113,6 +114,138 @@ def wrapped_func(*args, **kwargs):
return wrapped_func


# TODO: Add more extensive display.
class Display(stateful.Stateful):
    """Console reporter for the progress of a hyperparameter search.

    Prints a short summary when a trial begins and when it ends, and keeps
    the timing bookkeeping those summaries need. The bookkeeping can be
    saved and restored through `get_state()` / `set_state()`.

    Args:
        oracle: The oracle whose trials are reported. Its `trials`,
            `objective` and `get_best_trials()` are read when printing.
        verbose: Numeric verbosity level. While it is below 1 nothing is
            printed and nothing is recorded.
    """

    def __init__(self, oracle, verbose=1):
        self.verbose = verbose
        self.oracle = oracle
        # Width of the first two table columns; also the maximum length of
        # a non-numeric value before it is truncated.
        self.col_width = 18

        # Start time for the overall search
        self.search_start = None

        # Start time of the trials
        # {trial_id: start_time}
        self.trial_start = {}
        # Trial number of the trials, starting from #1.
        # {trial_id: trial_number}
        self.trial_number = {}

    def get_state(self):
        """Return the bookkeeping with datetimes encoded as ISO strings."""
        search_start = self.search_start
        if search_start is not None:
            search_start = search_start.isoformat()
        return {
            "search_start": search_start,
            "trial_start": {
                trial_id: started.isoformat()
                for trial_id, started in self.trial_start.items()
            },
            "trial_number": self.trial_number,
        }

    def set_state(self, state):
        """Restore the bookkeeping produced by `get_state()`."""
        search_start = state["search_start"]
        if search_start is not None:
            search_start = datetime.fromisoformat(search_start)
        self.search_start = search_start
        self.trial_start = {
            trial_id: datetime.fromisoformat(started)
            for trial_id, started in state["trial_start"].items()
        }
        self.trial_number = state["trial_number"]

    def on_trial_begin(self, trial):
        """Record the start of `trial` and print its hyperparameters.

        Does nothing when `verbose` is below 1.
        """
        if self.verbose < 1:
            return

        start_time = datetime.now()
        self.trial_start[trial.trial_id] = start_time
        if self.search_start is None:
            # The first reported trial marks the start of the search.
            self.search_start = start_time
        current_number = len(self.oracle.trials)
        self.trial_number[trial.trial_id] = current_number

        print()
        print(f"Search: Running Trial #{current_number}")
        print()
        self.show_hyperparameter_table(trial)
        print()

    def on_trial_end(self, trial):
        """Print the result of `trial` and the best score so far.

        Does nothing when `verbose` is below 1. The start time and number
        of the trial may be unknown, e.g. because the trial began while
        `verbose` was below 1 or because the restored state does not
        contain it. In that case the unknown parts are left out of the
        report instead of raising.
        """
        if self.verbose < 1:
            return

        utils.try_clear()

        # Fall back to the trial id when no trial number was recorded.
        trial_label = self.trial_number.get(trial.trial_id, trial.trial_id)
        started_at = self.trial_start.get(trial.trial_id)
        if started_at is None:
            print(f"Trial {trial_label} Complete")
        else:
            time_taken_str = self.format_duration(datetime.now() - started_at)
            print(f"Trial {trial_label} Complete [{time_taken_str}]")

        if trial.score is not None:
            print(f"{self.oracle.objective.name}: {trial.score}")

        print()
        best_trials = self.oracle.get_best_trials()
        best_score = best_trials[0].score if best_trials else None
        print(f"Best {self.oracle.objective.name} So Far: {best_score}")

        # `search_start` is only set by a non-silent `on_trial_begin()`.
        if self.search_start is not None:
            time_elapsed_str = self.format_duration(
                datetime.now() - self.search_start
            )
            print(f"Total elapsed time: {time_elapsed_str}")

    def show_hyperparameter_table(self, trial):
        """Print the trial's hyperparameter values next to the best ones."""
        template = "{{0:{0}}}|{{1:{0}}}|{{2}}".format(self.col_width)
        best_trials = self.oracle.get_best_trials()
        best_trial = best_trials[0] if best_trials else None

        if not trial.hyperparameters.values:
            print("default configuration")
            return

        print(template.format("Value", "Best Value So Far", "Hyperparameter"))
        for hp, value in trial.hyperparameters.values.items():
            # "?" marks that there is no best trial to compare against yet.
            best_value = (
                best_trial.hyperparameters.values.get(hp)
                if best_trial
                else "?"
            )
            print(
                template.format(
                    self.format_value(value),
                    self.format_value(best_value),
                    hp,
                )
            )

    def format_value(self, val):
        """Format a table cell.

        Non-boolean numbers use 5 significant digits; anything else is
        converted with `str()` and truncated to `col_width` characters.
        """
        if isinstance(val, (int, float)) and not isinstance(val, bool):
            return f"{val:.5g}"
        val_str = str(val)
        if len(val_str) > self.col_width:
            val_str = f"{val_str[:self.col_width - 3]}..."
        return val_str

    def format_duration(self, d):
        """Format a `timedelta` as `[Dd ]HHh MMm SSs`.

        The duration is rounded to whole seconds; the day part is only
        included when it is greater than zero.
        """
        total_seconds = round(d.total_seconds())
        days, remainder = divmod(total_seconds, 86400)
        hours, remainder = divmod(remainder, 3600)
        minutes, seconds = divmod(remainder, 60)

        if days > 0:
            return f"{days:d}d {hours:02d}h {minutes:02d}m {seconds:02d}s"
        return f"{hours:02d}h {minutes:02d}m {seconds:02d}s"


@keras_tuner_export("keras_tuner.Oracle")
class Oracle(stateful.Stateful):
"""Implements a hyperparameter optimization algorithm.
Expand Down Expand Up @@ -234,6 +367,19 @@ def __init__(
self.max_retries_per_trial = max_retries_per_trial
self.max_consecutive_failed_trials = max_consecutive_failed_trials

# Print the logs to screen
self._display = Display(oracle=self)

@property
def verbose(self):
    """Verbosity level currently held by the display."""
    return self._display.verbose

@verbose.setter
def verbose(self, value):
    """Set the display verbosity; `"auto"` is stored as 1."""
    self._display.verbose = 1 if value == "auto" else value

def _populate_space(self, trial_id):
warnings.warn(
"The `_populate_space` method is deprecated, "
Expand Down Expand Up @@ -309,6 +455,7 @@ def create_trial(self, tuner_id):
trial.status = trial_module.TrialStatus.RUNNING
self.ongoing_trials[tuner_id] = trial
self.save()
self._display.on_trial_begin(trial)
return trial

# Make the trial_id the current number of trial, pre-padded with 0s
Expand Down Expand Up @@ -342,6 +489,7 @@ def create_trial(self, tuner_id):
self.start_order.append(trial_id)
self._save_trial(trial)
self.save()
self._display.on_trial_begin(trial)

return trial

Expand Down Expand Up @@ -431,6 +579,8 @@ def end_trial(self, trial):
self._save_trial(trial)
self.save()

self._display.on_trial_end(trial)

# Pop the ongoing trial at last, which would notify the chief server to
# stop when ongoing_trials is empty.
for tuner_id, ongoing_trial in self.ongoing_trials.items():
Expand Down Expand Up @@ -547,6 +697,7 @@ def get_state(self):
"seed_state": self._seed_state,
"tried_so_far": list(self._tried_so_far),
"id_to_hash": self._id_to_hash,
"display": self._display.get_state(),
}

def set_state(self, state):
Expand All @@ -568,6 +719,7 @@ def set_state(self, state):
self._tried_so_far = set(state["tried_so_far"])
self._id_to_hash = collections.defaultdict(lambda: None)
self._id_to_hash.update(state["id_to_hash"])
self._display.set_state(state["display"])

def _set_project_dir(self, directory, project_name):
"""Sets the project directory and reloads the Oracle."""
Expand Down
10 changes: 10 additions & 0 deletions keras_tuner/engine/oracle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import threading
import time

Expand All @@ -21,6 +22,7 @@
import keras_tuner
from keras_tuner.engine import oracle as oracle_module
from keras_tuner.engine import trial as trial_module
from keras_tuner.tuners import gridsearch


class OracleStub(oracle_module.Oracle):
Expand Down Expand Up @@ -440,3 +442,11 @@ def test_overwrite_false_resume(tmp_path):
assert (
oracle.get_trial(trial_id).status == trial_module.TrialStatus.COMPLETED
)


def test_display_format_duration_large_d():
    """A one-week duration shows a day count; "auto" verbosity becomes 1."""
    oracle = gridsearch.GridSearchOracle()
    one_week = datetime.datetime(2020, 5, 17) - datetime.datetime(2020, 5, 10)
    oracle.verbose = "auto"
    display = oracle_module.Display(oracle)
    assert display.format_duration(one_week) == "7d 00h 00m 00s"
    assert oracle.verbose == 1
5 changes: 4 additions & 1 deletion keras_tuner/engine/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,10 @@ def test_correct_display_trial_number(tmp_path):
new_tuner.search(
TRAIN_INPUTS, TRAIN_TARGETS, validation_data=(VAL_INPUTS, VAL_TARGETS)
)
assert len(new_tuner.oracle.trials) == new_tuner._display.trial_number
new_tuner.oracle._display.trial_number.items()
assert len(new_tuner.oracle.trials) == max(
new_tuner.oracle._display.trial_number.values()
)


def test_error_on_unknown_objective_direction(tmp_path):
Expand Down
Loading