From c4bd2a963603e91f6039febd6c2a5e2949f216e9 Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Thu, 4 Feb 2021 09:23:19 -0800 Subject: [PATCH] Add pydot package (#956) This is used by keras plot_model() function. Included a test to prevent regression. Improved the existing tf test with extra assertions. The issue was raised by a user here: https://www.kaggle.com/c/jane-street-market-prediction/discussion/214494#1184233 --- Dockerfile | 1 + tests/test_tensorflow.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9d487c73..f7080c5b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -80,6 +80,7 @@ RUN apt-get install -y libfreetype6-dev && \ pip install xgboost && \ # Pinned to match GPU version. Update version together. pip install lightgbm==3.1.1 && \ + pip install pydot && \ pip install keras && \ pip install keras-tuner && \ pip install flake8 && \ diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py index 938a88f3..8b415e25 100644 --- a/tests/test_tensorflow.py +++ b/tests/test_tensorflow.py @@ -1,4 +1,5 @@ import unittest +import os.path import numpy as np import tensorflow as tf @@ -37,8 +38,13 @@ def test_tf_keras(self): metrics=['accuracy']) model.fit(x_train, y_train, epochs=1) - model.evaluate(x_test, y_test) - + + result = model.evaluate(x_test, y_test) + self.assertEqual(2, len(result)) + + # exercices pydot path. + tf.keras.utils.plot_model(model, to_file="tf_plot_model.png") + self.assertTrue(os.path.isfile("tf_plot_model.png")) def test_lstm(self): x_train = np.random.random((100, 28, 28)) @@ -58,7 +64,8 @@ def test_lstm(self): metrics=['accuracy']) model.fit(x_train, y_train, epochs=1) - model.evaluate(x_test, y_test) + result = model.evaluate(x_test, y_test) + self.assertEqual(2, len(result)) @gpu_test def test_gpu(self):