Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 45 additions & 253 deletions lib/iris/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,8 @@

The primary class for this module is :class:`IrisTest`.

By default, this module sets the matplotlib backend to "agg". But when
this module is imported it checks ``sys.argv`` for the flag "-d". If
found, it is removed from ``sys.argv`` and the matplotlib backend is
switched to "tkagg" to allow the interactive visual inspection of
graphical test results.

"""

import codecs
import collections
from collections.abc import Mapping
import contextlib
Expand All @@ -38,40 +31,23 @@
import shutil
import subprocess
import sys
import threading
from typing import Dict, List
import unittest
from unittest import mock
import warnings
import xml.dom.minidom
import zlib

import filelock
import numpy as np
import numpy.ma as ma
import requests

import iris.config
import iris.cube
import iris.tests.graphics as graphics
import iris.util

# Test for availability of matplotlib.
# (And remove matplotlib as an iris.tests dependency.)
try:
import matplotlib

# Override any user settings e.g. from matplotlibrc file.
matplotlib.rcdefaults()
# Set backend *after* rcdefaults, as we don't want that overridden (#3846).
matplotlib.use("agg")
# Standardise the figure size across matplotlib versions.
# This permits matplotlib png image comparison.
matplotlib.rcParams["figure.figsize"] = [8.0, 6.0]
import matplotlib.pyplot as plt
except ImportError:
MPL_AVAILABLE = False
else:
MPL_AVAILABLE = True
MPL_AVAILABLE = graphics.MPL_AVAILABLE


try:
from osgeo import gdal # noqa
Expand Down Expand Up @@ -111,10 +87,6 @@

#: Basepath for test results.
_RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
#: Default perceptual hash size.
_HASH_SIZE = 16
#: Default maximum perceptual hash hamming distance.
_HAMMING_DISTANCE = 2

if "--data-files-used" in sys.argv:
sys.argv.remove("--data-files-used")
Expand All @@ -131,18 +103,6 @@
os.environ["IRIS_TEST_CREATE_MISSING"] = "true"


# Whether to display matplotlib output to the screen.
_DISPLAY_FIGURES = False

if MPL_AVAILABLE and "-d" in sys.argv:
sys.argv.remove("-d")
plt.switch_backend("tkagg")
_DISPLAY_FIGURES = True

# Threading non re-entrant blocking lock to ensure thread-safe plotting.
_lock = threading.Lock()


def main():
"""A wrapper for unittest.main() which adds iris.test specific options to the help (-h) output."""
if "-h" in sys.argv or "--help" in sys.argv:
Expand Down Expand Up @@ -179,43 +139,6 @@ def main():
unittest.main()


def get_data_path(relative_path):
    """
    Return the absolute path to a data file when given the relative path
    as a string, or sequence of strings.

    The path is resolved against ``iris.config.TEST_DATA_DIR`` (an unset
    directory is treated as the empty string). If data path exporting is
    enabled, the resolved path is also written to the export file.

    If the resolved file does not exist but a gzipped sibling
    (``<path>.gz``) does, the archive is decompressed alongside it. When
    that location cannot be written, the data is decompressed into a
    temporary file instead and the temporary path is returned.

    """
    if not isinstance(relative_path, str):
        relative_path = os.path.join(*relative_path)
    test_data_dir = iris.config.TEST_DATA_DIR
    if test_data_dir is None:
        test_data_dir = ""
    data_path = os.path.join(test_data_dir, relative_path)

    if _EXPORT_DATAPATHS_FILE is not None:
        _EXPORT_DATAPATHS_FILE.write(data_path + "\n")

    if not os.path.exists(data_path):
        # if the file is gzipped, ungzip it and return the path of the
        # ungzipped file.
        gzipped_fname = data_path + ".gz"
        if os.path.exists(gzipped_fname):
            with gzip.open(gzipped_fname, "rb") as gz_fh:
                try:
                    with open(data_path, "wb") as fh:
                        shutil.copyfileobj(gz_fh, fh)
                except IOError:
                    # The write may have failed part-way through: discard
                    # any truncated output so a later call cannot mistake
                    # it for the real data file ...
                    with contextlib.suppress(OSError):
                        os.remove(data_path)
                    # ... and rewind the archive so the fallback copy
                    # starts from the first byte again.
                    gz_fh.seek(0)
                    # Put ungzipped data file in a temporary path, since we
                    # can't write to the original path (maybe it is owned by
                    # the system.)
                    _, ext = os.path.splitext(data_path)
                    data_path = iris.util.create_temp_filename(suffix=ext)
                    with open(data_path, "wb") as fh:
                        shutil.copyfileobj(gz_fh, fh)

    return data_path


class IrisTest_nometa(unittest.TestCase):
"""A subclass of unittest.TestCase which provides Iris specific testing functionality."""

Expand Down Expand Up @@ -250,6 +173,43 @@ def _assert_str_same(
% (type_comparison_name, reference_filename, diff)
)

@staticmethod
def get_data_path(relative_path):
    """
    Return the absolute path to a data file when given the relative path
    as a string, or sequence of strings.

    The path is resolved against ``iris.config.TEST_DATA_DIR`` (an unset
    directory is treated as the empty string). If data path exporting is
    enabled, the resolved path is also written to the export file.

    If the resolved file does not exist but a gzipped sibling
    (``<path>.gz``) does, the archive is decompressed alongside it. When
    that location cannot be written, the data is decompressed into a
    temporary file instead and the temporary path is returned.

    """
    if not isinstance(relative_path, str):
        relative_path = os.path.join(*relative_path)
    test_data_dir = iris.config.TEST_DATA_DIR
    if test_data_dir is None:
        test_data_dir = ""
    data_path = os.path.join(test_data_dir, relative_path)

    if _EXPORT_DATAPATHS_FILE is not None:
        _EXPORT_DATAPATHS_FILE.write(data_path + "\n")

    if not os.path.exists(data_path):
        # if the file is gzipped, ungzip it and return the path of the
        # ungzipped file.
        gzipped_fname = data_path + ".gz"
        if os.path.exists(gzipped_fname):
            with gzip.open(gzipped_fname, "rb") as gz_fh:
                try:
                    with open(data_path, "wb") as fh:
                        shutil.copyfileobj(gz_fh, fh)
                except IOError:
                    # The write may have failed part-way through: discard
                    # any truncated output so a later call cannot mistake
                    # it for the real data file ...
                    with contextlib.suppress(OSError):
                        os.remove(data_path)
                    # ... and rewind the archive so the fallback copy
                    # starts from the first byte again.
                    gz_fh.seek(0)
                    # Put ungzipped data file in a temporary path, since we
                    # can't write to the original path (maybe it is owned by
                    # the system.)
                    _, ext = os.path.splitext(data_path)
                    data_path = iris.util.create_temp_filename(suffix=ext)
                    with open(data_path, "wb") as fh:
                        shutil.copyfileobj(gz_fh, fh)

    return data_path
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wjbenfold Nice, thanks for aligning get_data_path with get_result_path 👍

They have the same use case and yet through the mists of time one was kept as a function whilst the other as a static method. At least now they are defined in the same namespace and used in the same way 😄


@staticmethod
def get_result_path(relative_path):
"""
Expand Down Expand Up @@ -872,137 +832,7 @@ def check_graphic(self):
output directory, and the imagerepo.json file being updated.

"""
from PIL import Image
import imagehash

dev_mode = os.environ.get("IRIS_TEST_CREATE_MISSING")
unique_id = self._unique_id()
repo_fname = os.path.join(_RESULT_PATH, "imagerepo.json")
with open(repo_fname, "rb") as fi:
repo: Dict[str, List[str]] = json.load(
codecs.getreader("utf-8")(fi)
)

try:
#: The path where the images generated by the tests should go.
image_output_directory = os.path.join(
os.path.dirname(__file__), "result_image_comparison"
)
if not os.access(image_output_directory, os.W_OK):
if not os.access(os.getcwd(), os.W_OK):
raise IOError(
"Write access to a local disk is required "
"to run image tests. Run the tests from a "
"current working directory you have write "
"access to to avoid this issue."
)
else:
image_output_directory = os.path.join(
os.getcwd(), "iris_image_test_output"
)
result_fname = os.path.join(
image_output_directory, "result-" + unique_id + ".png"
)

if not os.path.isdir(image_output_directory):
# Handle race-condition where the directories are
# created sometime between the check above and the
# creation attempt below.
try:
os.makedirs(image_output_directory)
except OSError as err:
# Don't care about "File exists"
if err.errno != 17:
raise

def _create_missing():
fname = "{}.png".format(phash)
base_uri = (
"https://scitools.github.io/test-iris-imagehash/"
"images/v4/{}"
)
uri = base_uri.format(fname)
hash_fname = os.path.join(image_output_directory, fname)
uris = repo.setdefault(unique_id, [])
uris.append(uri)
print("Creating image file: {}".format(hash_fname))
figure.savefig(hash_fname)
msg = "Creating imagerepo entry: {} -> {}"
print(msg.format(unique_id, uri))
lock = filelock.FileLock(
os.path.join(_RESULT_PATH, "imagerepo.lock")
)
# The imagerepo.json file is a critical resource, so ensure
# thread safe read/write behaviour via platform independent
# file locking.
with lock.acquire(timeout=600):
with open(repo_fname, "wb") as fo:
json.dump(
repo,
codecs.getwriter("utf-8")(fo),
indent=4,
sort_keys=True,
)

# Calculate the test result perceptual image hash.
buffer = io.BytesIO()
figure = plt.gcf()
figure.savefig(buffer, format="png")
buffer.seek(0)
phash = imagehash.phash(Image.open(buffer), hash_size=_HASH_SIZE)

if unique_id not in repo:
# The unique id might not be fully qualified, e.g.
# expects iris.tests.test_quickplot.TestLabels.test_contour.0,
# but got test_quickplot.TestLabels.test_contour.0
# if we find single partial match from end of the key
# then use that, else fall back to the unknown id state.
matches = [key for key in repo if key.endswith(unique_id)]
if len(matches) == 1:
unique_id = matches[0]

if unique_id in repo:
uris = repo[unique_id]
# Extract the hex basename strings from the uris.
hexes = [
os.path.splitext(os.path.basename(uri))[0] for uri in uris
]
# Create the expected perceptual image hashes from the uris.
to_hash = imagehash.hex_to_hash
expected = [to_hash(uri_hex) for uri_hex in hexes]

# Calculate hamming distance vector for the result hash.
distances = [e - phash for e in expected]

if np.all([hd > _HAMMING_DISTANCE for hd in distances]):
if dev_mode:
_create_missing()
else:
figure.savefig(result_fname)
msg = (
"Bad phash {} with hamming distance {} "
"for test {}."
)
msg = msg.format(phash, distances, unique_id)
if _DISPLAY_FIGURES:
emsg = "Image comparison would have failed: {}"
print(emsg.format(msg))
else:
emsg = "Image comparison failed: {}"
raise AssertionError(emsg.format(msg))
else:
if dev_mode:
_create_missing()
else:
figure.savefig(result_fname)
emsg = "Missing image test result: {}."
raise AssertionError(emsg.format(unique_id))

if _DISPLAY_FIGURES:
plt.show()

finally:
plt.close()
graphics.check_graphic(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wjbenfold Interesting 🤔

I understand the motivation to corral all the graphic related infrastructure into one module 👍

However this one line is kinda strongly hinting that the definition of check_graphic belongs within the class as a method and not as a method simply wrapping a function. Also, the graphics.check_graphic function isn't called directly from anywhere else in the codebase (that I can see, so correct me if I'm wrong here), which would bolster the argument of detaching it and making it a function.

I'd personally rather go one way or the other and not a hybrid like this, which seems like an artificial indirection (again, although I understand why you've done it).


def _remove_testcase_patches(self):
"""Helper to remove per-testcase patches installed by :meth:`patch`."""
Expand Down Expand Up @@ -1214,37 +1044,15 @@ class IrisTest(IrisTest_nometa, metaclass=_TestTimingsMetaclass):
pass


get_data_path = IrisTest.get_data_path
get_result_path = IrisTest.get_result_path


class GraphicsTestMixin:
    """
    Mixin for plotting test cases.

    ``setUp`` acquires the module-level plotting lock and ``tearDown``
    releases it, so each test holds the lock for its whole duration.
    When matplotlib is available, all open figures are closed both
    before and after each test.

    """

    # nose directive: dispatch tests concurrently.
    _multiprocess_can_split_ = True

    def setUp(self):
        # Acquire threading non re-entrant blocking lock to ensure
        # thread-safe plotting.
        _lock.acquire()
        try:
            # Make sure we have no unclosed plots from previous tests before
            # generating this one.
            if MPL_AVAILABLE:
                plt.close("all")
        except BaseException:
            # unittest does not call tearDown when setUp fails, so the
            # lock must be released here or later tests would block on it.
            _lock.release()
            raise

    def tearDown(self):
        try:
            # If a plotting test bombs out it can leave the current figure
            # in an odd state, so we make sure it's been disposed of.
            if MPL_AVAILABLE:
                plt.close("all")
        finally:
            # Release the non re-entrant blocking lock, even if closing
            # the figures raised.
            _lock.release()


class GraphicsTest(GraphicsTestMixin, IrisTest):
class GraphicsTest(graphics.GraphicsTestMixin, IrisTest):
pass


class GraphicsTest_nometa(GraphicsTestMixin, IrisTest_nometa):
class GraphicsTest_nometa(graphics.GraphicsTestMixin, IrisTest_nometa):
# Graphicstest without the metaclass providing test timings.
pass

Expand Down Expand Up @@ -1290,23 +1098,7 @@ class MyGeoTiffTests(test.IrisTest):
return skip(fn)


def skip_plot(fn):
    """
    Decorator which skips the wrapped test (function or class) when the
    matplotlib library is not available.

    Example usage:
        @skip_plot
        class MyPlotTests(test.GraphicsTest):
            ...

    """
    matplotlib_missing = not MPL_AVAILABLE
    decorate = unittest.skipIf(
        matplotlib_missing,
        "Graphics tests require the matplotlib library.",
    )
    return decorate(fn)
skip_plot = graphics.skip_plot


skip_sample_data = unittest.skipIf(
Expand Down
Loading