Merge pull request mmistakes#71 from pesser/misc
Misc
theRealSuperMario authored Jun 24, 2019
2 parents d7e1e63 + f0ea94c commit fec4e39
Showing 2 changed files with 30 additions and 3 deletions.
4 changes: 4 additions & 0 deletions edflow/iterators/tf_evaluator.py
@@ -77,6 +77,7 @@ def __init__(self, *args, desc="Eval", hook_freq=1, num_epochs=1, **kwargs):
         """
         kwargs.update({"desc": desc, "hook_freq": hook_freq, "num_epochs": num_epochs})
         super().__init__(*args, **kwargs)
+        self.define_graph()
 
     def initialize(self, checkpoint_path=None):
         self.restore_variables = self.model.variables
@@ -101,5 +102,8 @@ def initialize(self, checkpoint_path=None):
            )
            self.hooks += [waiter]
 
+    def define_graph(self):
+        pass
+
     def step_ops(self):
         return self.model.outputs
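
The new define_graph() hook is called at the end of __init__, giving subclasses a well-defined place to build additional evaluation ops before initialize() restores variables. A minimal sketch of how it could be used, assuming the evaluator class in this file is TFBaseEvaluator (the subclass and the op it builds are hypothetical):

    from edflow.iterators.tf_evaluator import TFBaseEvaluator

    class MyEvaluator(TFBaseEvaluator):
        def define_graph(self):
            # Called once at the end of __init__, before initialize()
            # restores variables; build any extra eval ops here.
            self.extra_outputs = ...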
29 changes: 26 additions & 3 deletions edflow/iterators/tf_trainer.py
@@ -1,4 +1,5 @@
 import tensorflow as tf
+import numpy as np
 
 from edflow.iterators.tf_iterator import HookedModelIterator, TFHookedModelIterator
 
@@ -233,7 +234,7 @@ def setup(self):
            )
            ihook = IntervalHook(
                [loghook],
-                interval=1,
+                interval=self.config.get("start_log_freq", 1),
                modify_each=1,
                max_interval=self.config.get("log_freq", 1000),
                get_step=self.get_global_step,
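
With this change the initial logging interval comes from the config instead of being hard-coded to 1: the IntervalHook starts at start_log_freq and grows the interval as training proceeds, capped at log_freq. A hypothetical config illustrating both keys (the values are made up):

    config = {
        "start_log_freq": 10,  # first log every 10 steps instead of every step
        "log_freq": 1000,      # the interval is never grown past 1000 steps
    }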
@@ -293,7 +294,18 @@ def create_train_op(self):
             opt_ops[k] = self.optimizers[k].minimize(losses[k], var_list=variables)
             print(k)
             print("============================")
-            print("\n".join([v.name for v in variables]))
+            print(
+                "\n".join(
+                    [
+                        "{:>22} {:>22} {}".format(
+                            str(v.shape.as_list()),
+                            str(np.prod(v.shape.as_list())),
+                            v.name,
+                        )
+                        for v in variables
+                    ]
+                )
+            )
             print(len(variables))
             print("============================")
         self.opt_ops = opt_ops
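
The reworked printout lists each variable's shape and total parameter count (np.prod of the shape) right-aligned next to its name, which is why numpy is now imported. A standalone snippet reproducing the format, with made-up shapes and names:

    import numpy as np

    variables = [([3, 3, 64, 128], "conv1/kernel:0"), ([128], "conv1/bias:0")]
    for shape, name in variables:
        print("{:>22} {:>22} {}".format(str(shape), str(np.prod(shape)), name))
    # prints:
    #        [3, 3, 64, 128]                  73728 conv1/kernel:0
    #                  [128]                    128 conv1/bias:0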
@@ -418,7 +430,18 @@ def create_train_op(self):
             opt_ops[k] = optimizers[k].minimize(losses[k], var_list=variables)
             print(i, k, self.get_learning_rate_multiplier(i))
             print("============================")
-            print("\n".join([v.name for v in variables]))
+            print(
+                "\n".join(
+                    [
+                        "{:>22} {:>22} {}".format(
+                            str(v.shape.as_list()),
+                            str(np.prod(v.shape.as_list())),
+                            v.name,
+                        )
+                        for v in variables
+                    ]
+                )
+            )
             print(len(variables))
             print("============================")
 
