diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 843ddaa5e3ff..1d0997252018 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1383,6 +1383,10 @@ def test_keras_fit(self):
                 else:
                     metrics = []
 
+                model(model.dummy_inputs)  # Build the model so we can get some constant weights
+                model_weights = model.get_weights()
+
+                # Run eagerly to save some expensive compilation times
                 model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
                 # Make sure the model fits without crashing regardless of where we pass the labels
                 history1 = model.fit(
@@ -1394,6 +1398,11 @@ def test_keras_fit(self):
                 )
                 val_loss1 = history1.history["val_loss"][0]
                 accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+
+                # We reinitialize the model here even though our learning rate was zero
+                # because BatchNorm updates weights by means other than gradient descent.
+                model.set_weights(model_weights)
+
                 history2 = model.fit(
                     inputs_minus_labels,
                     labels,
@@ -1403,7 +1412,7 @@ def test_keras_fit(self):
                     shuffle=False,
                 )
                 val_loss2 = history2.history["val_loss"][0]
-                accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+                accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
                 self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
                 self.assertEqual(history1.history.keys(), history2.history.keys())
                 for key in history1.history.keys():
@@ -1416,6 +1425,10 @@ def test_keras_fit(self):
                 dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
                 # Pass in all samples as a batch to match other `fit` calls
                 dataset = dataset.batch(len(dataset))
+
+                # Reinitialize to fix batchnorm again
+                model.set_weights(model_weights)
+
                 history3 = model.fit(
                     dataset,
                     validation_data=dataset,
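
A minimal standalone sketch (not part of the diff above; assumes TF 2.x and uses an illustrative toy model) of the behavior the added `set_weights` calls guard against: BatchNormalization's moving mean and variance are non-trainable weights updated by an exponential moving average during training, so they drift across `fit()` calls even when the optimizer's learning rate is zero.

    import numpy as np
    import tensorflow as tf

    # A single BatchNorm layer is enough to show the effect.
    model = tf.keras.Sequential([tf.keras.layers.BatchNormalization(input_shape=(4,))])
    model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse")

    initial_weights = model.get_weights()  # captured after the model is built, as in the test
    x = np.random.randn(32, 4).astype("float32")
    model.fit(x, x, epochs=1, verbose=0)

    drifted = any(
        not np.allclose(before, after)
        for before, after in zip(initial_weights, model.get_weights())
    )
    print(drifted)  # True: moving_mean/moving_variance changed despite lr == 0.0

    # Restoring the captured weights makes the next fit() start from an identical
    # state, which is what the test's val_loss/accuracy comparisons rely on.
    model.set_weights(initial_weights)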