Skip to content

Commit 9c34e65

Browse files
authored
Merge pull request #242 from NREL/gb/optm_state
Gb/optm state
2 parents e43e4bb + a74ed5c commit 9c34e65

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

sup3r/models/abstract.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ def get_optimizer_config(optimizer):
10371037
Parameters
10381038
----------
10391039
optimizer : tf.keras.optimizers.Optimizer
1040-
TF-Keras optimizer object
1040+
TF-Keras optimizer object (e.g., Adam)
10411041
10421042
Returns
10431043
-------
@@ -1053,6 +1053,29 @@ def get_optimizer_config(optimizer):
10531053
conf[k] = int(v)
10541054
return conf
10551055

1056+
@classmethod
def get_optimizer_state(cls, optimizer):
    """Get a set of state variables for the optimizer

    Parameters
    ----------
    optimizer : tf.keras.optimizers.Optimizer
        TF-Keras optimizer object (e.g., Adam)

    Returns
    -------
    state : dict
        Optimizer state variables
    """
    config = cls.get_optimizer_config(optimizer)
    state = {'learning_rate': config['learning_rate']}
    for variable in optimizer.variables:
        # collapse each optimizer ndarray into its mean absolute value
        # so the state dict holds plain scalars
        flat = variable.numpy().flatten()
        state[variable.name] = float(np.abs(flat).mean())
    return state
1078+
10561079
@staticmethod
10571080
def update_loss_details(loss_details, new_data, batch_len, prefix=None):
10581081
"""Update a dictionary of loss_details with loss information from a new

sup3r/models/base.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,10 @@ def train_epoch(
746746

747747
b_loss_details['gen_trained_frac'] = float(trained_gen)
748748
b_loss_details['disc_trained_frac'] = float(trained_disc)
749+
749750
self.dict_to_tensorboard(b_loss_details)
750751
self.dict_to_tensorboard(self.timer.log)
752+
751753
loss_details = self.update_loss_details(
752754
loss_details,
753755
b_loss_details,
@@ -1000,10 +1002,9 @@ def train(
10001002
loss_details['train_loss_gen'], loss_details['train_loss_disc']
10011003
)
10021004

1003-
if all(
1004-
loss in loss_details
1005-
for loss in ('val_loss_gen', 'val_loss_disc')
1006-
):
1005+
check1 = 'val_loss_gen' in loss_details
1006+
check2 = 'val_loss_disc' in loss_details
1007+
if check1 and check2:
10071008
msg += 'gen/disc val loss: {:.2e}/{:.2e} '.format(
10081009
loss_details['val_loss_gen'], loss_details['val_loss_disc']
10091010
)
@@ -1016,14 +1017,15 @@ def train(
10161017
'weight_gen_advers': weight_gen_advers,
10171018
'disc_loss_bound_0': disc_loss_bounds[0],
10181019
'disc_loss_bound_1': disc_loss_bounds[1],
1019-
'learning_rate_gen': self.get_optimizer_config(self.optimizer)[
1020-
'learning_rate'
1021-
],
1022-
'learning_rate_disc': self.get_optimizer_config(
1023-
self.optimizer_disc
1024-
)['learning_rate'],
10251020
}
10261021

1022+
opt_g = self.get_optimizer_state(self.optimizer)
1023+
opt_d = self.get_optimizer_state(self.optimizer_disc)
1024+
opt_g = {f'OptmGen/{key}': val for key, val in opt_g.items()}
1025+
opt_d = {f'OptmDisc/{key}': val for key, val in opt_d.items()}
1026+
extras.update(opt_g)
1027+
extras.update(opt_d)
1028+
10271029
weight_gen_advers = self.update_adversarial_weights(
10281030
loss_details,
10291031
adaptive_update_fraction,

sup3r/preprocessing/cachers/base.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sup3r.preprocessing.base import Container
1818
from sup3r.preprocessing.names import Dimension
1919
from sup3r.preprocessing.utilities import _mem_check, log_args, _lowered
20-
from sup3r.utilities.utilities import safe_cast
20+
from sup3r.utilities.utilities import safe_cast, safe_serialize
2121
from rex.utilities.utilities import to_records_array
2222

2323
from .utilities import _check_for_cache
@@ -475,7 +475,16 @@ def write_netcdf(
475475
ncfile.variables[dset][:] = np.asarray(data_var.data)
476476

477477
for attr_name, attr_value in attrs.items():
478-
ncfile.setncattr(attr_name, safe_cast(attr_value))
478+
attr_value = safe_cast(attr_value)
479+
try:
480+
ncfile.setncattr(attr_name, attr_value)
481+
except Exception as e:
482+
msg = (f'Could not write {attr_name} as attribute, '
483+
f'serializing with json dumps, '
484+
f'received error: "{e}"')
485+
logger.warning(msg)
486+
warn(msg)
487+
ncfile.setncattr(attr_name, safe_serialize(attr_value))
479488

480489
for feature in features:
481490
cls.write_netcdf_chunks(

tests/training/test_train_gan.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,14 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):
110110

111111
assert np.allclose(model_params['optimizer']['learning_rate'], lr)
112112
assert np.allclose(model_params['optimizer_disc']['learning_rate'], lr)
113-
assert 'learning_rate_gen' in model.history
114-
assert 'learning_rate_disc' in model.history
113+
assert 'OptmGen/learning_rate' in model.history
114+
assert 'OptmDisc/learning_rate' in model.history
115+
116+
msg = ('Could not find OptmGen states in columns: '
117+
f'{sorted(model.history.columns)}')
118+
check = [col.startswith('OptmGen/Adam/v')
119+
for col in model.history.columns]
120+
assert any(check), msg
115121

116122
assert 'config_generator' in loaded.meta
117123
assert 'config_discriminator' in loaded.meta

0 commit comments

Comments
 (0)