Skip to content

Commit 158e3c1

Browse files
authored
hparams: fix demo (remove group_name argument) (#2258)
Summary: In #2188, I removed the `group_name` parameter from the Keras callback, but forgot to remove it from the call site in the demo code. Test Plan: Running `bazel run //tensorboard/plugins/hparams:hparams_demo` now works. Previously, it failed with the error, “unexpected keyword argument 'group_name'”. wchargin-branch: hparams-demo-no-group-name
1 parent 4de3936 commit 158e3c1

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

tensorboard/plugins/hparams/hparams_demo.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from __future__ import division
2424
from __future__ import print_function
2525

26-
import hashlib
2726
import math
2827
import os.path
2928
import random
@@ -161,7 +160,7 @@ def model_fn(hparams, seed):
161160
return model
162161

163162

164-
def run(data, base_logdir, session_id, group_id, hparams):
163+
def run(data, base_logdir, session_id, hparams):
165164
"""Run a training/validation session.
166165
167166
Flags must have been parsed for this function to behave.
@@ -170,8 +169,6 @@ def run(data, base_logdir, session_id, group_id, hparams):
170169
data: The data as loaded by `prepare_data()`.
171170
base_logdir: The top-level logdir to which to write summary data.
172171
session_id: A unique string ID for this session.
173-
group_id: The string ID of the session group that includes this
174-
session.
175172
hparams: A dict mapping hyperparameters in `HPARAMS` to values.
176173
"""
177174
model = model_fn(hparams=hparams, seed=session_id)
@@ -182,7 +179,7 @@ def run(data, base_logdir, session_id, group_id, hparams):
182179
update_freq=flags.FLAGS.summary_freq,
183180
profile_batch=0, # workaround for issue #2084
184181
)
185-
hparams_callback = hp.KerasCallback(logdir, hparams, group_name=group_id)
182+
hparams_callback = hp.KerasCallback(logdir, hparams)
186183
((x_train, y_train), (x_test, y_test)) = data
187184
result = model.fit(
188185
x=x_train,
@@ -224,7 +221,6 @@ def run_all(logdir, verbose=False):
224221
for group_index in xrange(flags.FLAGS.num_session_groups):
225222
hparams = {h: sample_uniform(h.domain, rng) for h in HPARAMS}
226223
hparams_string = str(hparams)
227-
group_id = hashlib.sha256(hparams_string.encode("utf-8")).hexdigest()
228224
for repeat_index in xrange(sessions_per_group):
229225
session_id = str(session_index)
230226
session_index += 1
@@ -239,7 +235,6 @@ def run_all(logdir, verbose=False):
239235
data=data,
240236
base_logdir=logdir,
241237
session_id=session_id,
242-
group_id=group_id,
243238
hparams=hparams,
244239
)
245240

0 commit comments

Comments
 (0)