Skip to content

Commit ef77b81

Browse files
committed
fixed test for optimizer state in trained model history
1 parent 3fbc196 commit ef77b81

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/training/test_train_gan.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,15 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):
110110

111111
assert np.allclose(model_params['optimizer']['learning_rate'], lr)
112112
assert np.allclose(model_params['optimizer_disc']['learning_rate'], lr)
113-
assert 'learning_rate_gen' in model.history
114-
assert 'learning_rate_disc' in model.history
113+
assert 'OptmGen/learning_rate' in model.history
114+
assert 'OptmDisc/learning_rate' in model.history
115+
116+
state_cols_b = ['OptmGen/Adam/m/conv3d/bias:0',
117+
'OptmGen/Adam/m/conv2d_transpose/bias:0']
118+
state_cols_k = ['OptmGen/Adam/v/conv3d/kernel:0',
119+
'OptmGen/Adam/v/conv2d_transpose/kernel:0']
120+
assert any(col in model.history for col in state_cols_b)
121+
assert any(col in model.history for col in state_cols_k)
115122

116123
assert 'config_generator' in loaded.meta
117124
assert 'config_discriminator' in loaded.meta

0 commit comments

Comments
 (0)