diff --git a/tensor2tensor/models/image_transformer_2d_test.py b/tensor2tensor/models/image_transformer_2d_test.py index 7f903fb15..7deddc870 100644 --- a/tensor2tensor/models/image_transformer_2d_test.py +++ b/tensor2tensor/models/image_transformer_2d_test.py @@ -35,8 +35,8 @@ def _test_img2img_transformer(self, net): hparams = image_transformer_2d.img2img_transformer2d_tiny() hparams.data_dir = "" p_hparams = registry.problem("image_celeba").get_hparams(hparams) - inputs = np.random.randint(256, size=(3, 4, 4, 3)) - targets = np.random.randint(256, size=(3, 8, 8, 3)) + inputs = np.random.randint(256, size=(batch_size, 4, 4, 3)) + targets = np.random.randint(256, size=(batch_size, 8, 8, 3)) with self.test_session() as session: features = { "inputs": tf.constant(inputs, dtype=tf.int32),