Skip to content

Commit

Permalink
Add pydot package (#956)
Browse files Browse the repository at this point in the history
This is used by keras plot_model() function.

Included a test to prevent regression.
Improved the existing tf test with extra assertions.

The issue was raised by a user here: https://www.kaggle.com/c/jane-street-market-prediction/discussion/214494#1184233
  • Loading branch information
rosbo authored Feb 4, 2021
1 parent ddc0c00 commit c4bd2a9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ RUN apt-get install -y libfreetype6-dev && \
pip install xgboost && \
# Pinned to match GPU version. Update version together.
pip install lightgbm==3.1.1 && \
pip install pydot && \
pip install keras && \
pip install keras-tuner && \
pip install flake8 && \
Expand Down
13 changes: 10 additions & 3 deletions tests/test_tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import os.path

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -37,8 +38,13 @@ def test_tf_keras(self):
metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1)
model.evaluate(x_test, y_test)


result = model.evaluate(x_test, y_test)
self.assertEqual(2, len(result))

# exercices pydot path.
tf.keras.utils.plot_model(model, to_file="tf_plot_model.png")
self.assertTrue(os.path.isfile("tf_plot_model.png"))

def test_lstm(self):
x_train = np.random.random((100, 28, 28))
Expand All @@ -58,7 +64,8 @@ def test_lstm(self):
metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1)
model.evaluate(x_test, y_test)
result = model.evaluate(x_test, y_test)
self.assertEqual(2, len(result))

@gpu_test
def test_gpu(self):
Expand Down

0 comments on commit c4bd2a9

Please sign in to comment.