Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin authored Oct 10, 2023
1 parent 2b07de2 commit 40cd10d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 1 addition & 2 deletions keras_tuner/distribute/oracle_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def build_model(hp):
max_trials=10,
directory=tmp_path,
)
tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Only worker makes it to this point, server runs until thread stops.
assert dist_utils.has_chief_oracle()
Expand All @@ -118,8 +119,6 @@ def build_model(hp):
tuner.oracle, keras_tuner.distribute.oracle_client.OracleClient
)

tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)

# Suppress warnings about optimizer state not being restored by
# tf.keras.

Expand Down
3 changes: 2 additions & 1 deletion keras_tuner/engine/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
self._populate_initial_space()

# Run in distributed mode.
if dist_utils.has_chief_oracle():
if dist_utils.has_chief_oracle() and not dist_utils.is_chief_oracle():
# Proxies requests to the chief oracle.
# Avoid import at the top, to avoid inconsistent protobuf versions.
from keras_tuner.distribute import oracle_client
Expand Down Expand Up @@ -216,6 +216,7 @@ def search(self, *fit_args, **fit_kwargs):
from keras_tuner.distribute import oracle_chief

oracle_chief.start_server(self.oracle)
return

self.on_search_begin()
while True:
Expand Down

0 comments on commit 40cd10d

Please sign in to comment.