From a66ef9954bb9b832edf963d0ee2a77312c0d73f0 Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 23 Jun 2022 14:30:29 +0100 Subject: [PATCH 1/3] Fix tests that broke when models used batchnorm --- tests/test_modeling_tf_common.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 843ddaa5e3ff..62de6e14a5b7 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1383,7 +1383,7 @@ def test_keras_fit(self): else: metrics = [] - model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics) + model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) # Make sure the model fits without crashing regardless of where we pass the labels history1 = model.fit( prepared_for_class, @@ -1394,6 +1394,12 @@ def test_keras_fit(self): ) val_loss1 = history1.history["val_loss"][0] accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")} + + # We reinitialize the model here even though our learning rate was zero + # because BatchNorm updates weights by means other than gradient descent. + model = model_class(config) + model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) + history2 = model.fit( inputs_minus_labels, labels, @@ -1416,6 +1422,11 @@ def test_keras_fit(self): dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class) # Pass in all samples as a batch to match other `fit` calls dataset = dataset.batch(len(dataset)) + + # Reinitialize to fix batchnorm again + model = model_class(config) + model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) + history3 = model.fit( dataset, validation_data=dataset, From d59814135c93e6d1a26fd01b7f7fc2d531e12c7e Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 23 Jun 2022 14:55:55 +0100 Subject: [PATCH 2/3] Initializing the model twice does not actually... ...give you the same weights each time. I am good at machine learning. --- tests/test_modeling_tf_common.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 62de6e14a5b7..889faf1f6a66 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1383,6 +1383,9 @@ def test_keras_fit(self): else: metrics = [] + model(model.dummy_inputs) # Build the model so we can get some constant weights + model_weights = model.get_weights() + model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) # Make sure the model fits without crashing regardless of where we pass the labels history1 = model.fit( @@ -1397,8 +1400,7 @@ def test_keras_fit(self): # We reinitialize the model here even though our learning rate was zero # because BatchNorm updates weights by means other than gradient descent. - model = model_class(config) - model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) + model.set_weights(model_weights) history2 = model.fit( inputs_minus_labels, @@ -1409,7 +1411,7 @@ def test_keras_fit(self): shuffle=False, ) val_loss2 = history2.history["val_loss"][0] - accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")} + accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")} self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3)) self.assertEqual(history1.history.keys(), history2.history.keys()) for key in history1.history.keys(): @@ -1424,8 +1426,7 @@ def test_keras_fit(self): dataset = dataset.batch(len(dataset)) # Reinitialize to fix batchnorm again - model = model_class(config) - model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) + model.set_weights(model_weights) history3 = model.fit( dataset, From 56bb4c95ee1348f04532d7ebfe5ec8261c5c24ba Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 23 Jun 2022 14:57:39 +0100 Subject: [PATCH 3/3] Fix speed regression --- tests/test_modeling_tf_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 889faf1f6a66..1d0997252018 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1386,7 +1386,8 @@ def test_keras_fit(self): model(model.dummy_inputs) # Build the model so we can get some constant weights model_weights = model.get_weights() - model.compile(optimizer=tf.keras.optimizers.SGD(0.0), metrics=metrics) + # Run eagerly to save some expensive compilation times + model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics) # Make sure the model fits without crashing regardless of where we pass the labels history1 = model.fit( prepared_for_class,