From ffcb4de4a6bc03ff0523a3bb150e102577325b68 Mon Sep 17 00:00:00 2001 From: Elias Sadek Date: Mon, 11 Mar 2024 10:37:34 +0000 Subject: [PATCH 1/4] added pytest iris class --- lib/iris/tests/iris_test_pytest.py | 1033 ++++++++++++++++++++++++++++ 1 file changed, 1033 insertions(+) create mode 100644 lib/iris/tests/iris_test_pytest.py diff --git a/lib/iris/tests/iris_test_pytest.py b/lib/iris/tests/iris_test_pytest.py new file mode 100644 index 0000000000..6a1b36b38d --- /dev/null +++ b/lib/iris/tests/iris_test_pytest.py @@ -0,0 +1,1033 @@ +# Copyright Iris contributors +# +# This file is part of Iris and is released under the BSD license. +# See LICENSE in the root of the repository for full licensing details. +"""Provides testing capabilities and customisations specific to Iris. + +.. note:: This module needs to control the matplotlib backend, so it + **must** be imported before ``matplotlib.pyplot``. + +The primary class for this module is :class:`IrisTest`. + +""" + +import collections +from collections.abc import Mapping +import contextlib +import difflib +import filecmp +import functools +import gzip +import inspect +import io +import json +import math +import os +import os.path +from pathlib import Path +import re +import shutil +import subprocess +import sys +from typing import AnyStr + +# from unittest import mock +import warnings +import xml.dom.minidom +import zlib + +import numpy as np +import numpy.ma as ma +import pytest +import requests + +import iris.config +import iris.cube +import iris.fileformats +import iris.tests.graphics as graphics +import iris.util + +MPL_AVAILABLE = graphics.MPL_AVAILABLE + + +try: + from osgeo import gdal # noqa +except ImportError: + GDAL_AVAILABLE = False +else: + GDAL_AVAILABLE = True + +try: + import iris_sample_data # noqa +except ImportError: + SAMPLE_DATA_AVAILABLE = False +else: + SAMPLE_DATA_AVAILABLE = True + +try: + import nc_time_axis # noqa + + NC_TIME_AXIS_AVAILABLE = True +except ImportError: + 
NC_TIME_AXIS_AVAILABLE = False + +try: + # Added a timeout to stop the call to requests.get hanging when running + # on a platform which has restricted/no internet access. + requests.get("https://github.com/SciTools/iris", timeout=10.0) + INET_AVAILABLE = True +except requests.exceptions.ConnectionError: + INET_AVAILABLE = False + +try: + import stratify # noqa + + STRATIFY_AVAILABLE = True +except ImportError: + STRATIFY_AVAILABLE = False + +#: Basepath for test results. +_RESULT_PATH = os.path.join(os.path.dirname(__file__), "results") + +if "--data-files-used" in sys.argv: + sys.argv.remove("--data-files-used") + fname = "/var/tmp/all_iris_test_resource_paths.txt" + print("saving list of files used by tests to %s" % fname) + _EXPORT_DATAPATHS_FILE = open(fname, "w") +else: + _EXPORT_DATAPATHS_FILE = None + + +if "--create-missing" in sys.argv: + sys.argv.remove("--create-missing") + print("Allowing creation of missing test results.") + os.environ["IRIS_TEST_CREATE_MISSING"] = "true" + + +def main(): + """A wrapper for pytest.main() which adds iris.test specific options to the help (-h) output.""" + if "-h" in sys.argv or "--help" in sys.argv: + stdout = sys.stdout + buff = io.StringIO() + # NB. pytest.main() raises an exception after it's shown the help text + try: + sys.stdout = buff + pytest.main() + finally: + sys.stdout = stdout + lines = buff.getvalue().split("\n") + lines.insert(9, "Iris-specific options:") + lines.insert( + 10, + " -d Display matplotlib figures (uses tkagg).", + ) + lines.insert( + 11, + " NOTE: To compare results of failing tests, ", + ) + lines.insert(12, " use idiff.py instead") + lines.insert( + 13, + " --data-files-used Save a list of files used to a temporary file", + ) + lines.insert(14, " -m Create missing test results") + print("\n".join(lines)) + else: + pytest.main() + + +def _assert_masked_array(assertion, a, b, strict, **kwargs): + # Compare masks. 
+ a_mask, b_mask = ma.getmaskarray(a), ma.getmaskarray(b) + np.testing.assert_array_equal(a_mask, b_mask) # pytest already? + + if strict: + # Compare all data values. + assertion(a.data, b.data, **kwargs) + else: + # Compare only unmasked data values. + assertion( + ma.compressed(a), + ma.compressed(b), + **kwargs, + ) + + +# np should be fine +def assert_masked_array_equal(a, b, strict=False): + """Check that masked arrays are equal. This requires the + unmasked values and masks to be identical. + + Parameters + ---------- + a, b : array-like + Two arrays to compare. + strict : bool, optional + If True, perform a complete mask and data array equality check. + If False (default), the data array equality considers only unmasked + elements. + + """ + _assert_masked_array(np.testing.assert_array_equal, a, b, strict) + + +def assert_masked_array_almost_equal(a, b, decimal=6, strict=False): + """Check that masked arrays are almost equal. This requires the + masks to be identical, and the unmasked values to be almost + equal. + + Parameters + ---------- + a, b : array-like + Two arrays to compare. + strict : bool, optional + If True, perform a complete mask and data array equality check. + If False (default), the data array equality considers only unmasked + elements. 
+ decimal : int, optional, default=6 + Equality tolerance level for + :meth:`numpy.testing.assert_array_almost_equal`, with the meaning + 'abs(desired-actual) < 0.5 * 10**(-decimal)' + + """ + _assert_masked_array( + np.testing.assert_array_almost_equal, a, b, strict, decimal=decimal + ) + + +# seems fine +def _assert_str_same( + reference_str, + test_str, + reference_filename, + type_comparison_name="Strings", +): + diff = "".join( + difflib.unified_diff( + reference_str.splitlines(1), + test_str.splitlines(1), + "Reference", + "Test result", + "", + "", + 0, + ) + ) + fail_string = ( + f"{type_comparison_name} do not match: {reference_filename}\n" f"{diff}" + ) + assert reference_str == test_str, fail_string + + +_assertion_counts = collections.defaultdict(int) + + +def get_data_path(relative_path): + """Return the absolute path to a data file when given the relative path + as a string, or sequence of strings. + + """ + if not isinstance(relative_path, str): + relative_path = os.path.join(*relative_path) + test_data_dir = iris.config.TEST_DATA_DIR + if test_data_dir is None: + test_data_dir = "" + data_path = os.path.join(test_data_dir, relative_path) + + if _EXPORT_DATAPATHS_FILE is not None: + _EXPORT_DATAPATHS_FILE.write(data_path + "\n") + + if isinstance(data_path, str) and not os.path.exists(data_path): + # if the file is gzipped, ungzip it and return the path of the ungzipped + # file. + gzipped_fname = data_path + ".gz" + if os.path.exists(gzipped_fname): + with gzip.open(gzipped_fname, "rb") as gz_fh: + try: + with open(data_path, "wb") as fh: + fh.writelines(gz_fh) + except IOError: + # Put ungzipped data file in a temporary path, since we + # can't write to the original path (maybe it is owned by + # the system.) 
+ _, ext = os.path.splitext(data_path) + data_path = iris.util.create_temp_filename(suffix=ext) + with open(data_path, "wb") as fh: + fh.writelines(gz_fh) + + return data_path + + +def get_result_path(relative_path): + """Returns the absolute path to a result file when given the relative path + as a string, or sequence of strings. + + """ + if not isinstance(relative_path, str): + relative_path = os.path.join(*relative_path) + return os.path.abspath(os.path.join(_RESULT_PATH, relative_path)) + + +def result_path(self, basename=None, ext=""): + """Return the full path to a test result, generated from the \ + calling file, class and, optionally, method. + + Parameters + ---------- + basename : optional, default=None + File basename. If omitted, this is generated from the calling method. + ext : str, optional, default="" + Appended file extension. + + """ + if ext and not ext.startswith("."): + ext = "." + ext + + # Generate the folder name from the calling file name. + path = os.path.abspath(inspect.getfile(self.__class__)) + path = os.path.splitext(path)[0] + sub_path = path.rsplit("iris", 1)[1].split("tests", 1)[1][1:] + + # Generate the file name from the calling function name? 
+ if basename is None: + stack = inspect.stack() + for frame in stack[1:]: + if "test_" in frame[3]: + basename = frame[3].replace("test_", "") + break + filename = basename + ext + + result = os.path.join( + self.get_result_path(""), + sub_path.replace("test_", ""), + self.__class__.__name__.replace("Test_", ""), + filename, + ) + return result + + +def assertCMLApproxData(cubes, reference_filename=None, **kwargs): + # passes args and kwargs on to approx equal + if isinstance(cubes, iris.cube.Cube): + cubes = [cubes] + if reference_filename is None: + reference_filename = result_path(None, "cml") + reference_filename = [get_result_path(reference_filename)] + for i, cube in enumerate(cubes): + fname = list(reference_filename) + # don't want the ".cml" for the json stats file + if fname[-1].endswith(".cml"): + fname[-1] = fname[-1][:-4] + fname[-1] += ".data.%d.json" % i + assertDataAlmostEqual(cube.data, fname, **kwargs) + assertCML(cubes, reference_filename, checksum=False) + + +def assertCDL(netcdf_filename, reference_filename=None, flags="-h"): + """Test that the CDL for the given netCDF file matches the contents + of the reference file. + + If the environment variable IRIS_TEST_CREATE_MISSING is + non-empty, the reference file is created if it doesn't exist. + + Parameters + ---------- + netcdf_filename : + The path to the netCDF file. + reference_filename : optional, default=None + The relative path (relative to the test results directory). + If omitted, the result is generated from the calling + method's name, class, and module using + :meth:`iris.tests.IrisTest.result_path`. + flags : str, optional + Command-line flags for `ncdump`, as either a whitespace + separated string or an iterable. Defaults to '-h'. + + """ + if reference_filename is None: + reference_path = result_path(None, "cdl") + else: + reference_path = get_result_path(reference_filename) + + # Convert the netCDF file to CDL file format. 
+ if flags is None: + flags = [] + elif isinstance(flags, str): + flags = flags.split() + else: + flags = list(map(str, flags)) + + try: + exe_path = env_bin_path("ncdump") + args = [exe_path] + flags + [netcdf_filename] + cdl = subprocess.check_output(args, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as exc: + print(exc.output) + raise + + # Ingest the CDL for comparison, excluding first line. + lines = cdl.decode("ascii").splitlines() + lines = lines[1:] + + # Ignore any lines of the general form "... :_NCProperties = ..." + # (an extra global attribute, displayed by older versions of ncdump). + re_ncprop = re.compile(r"^\s*:_NCProperties *=") + lines = [line for line in lines if not re_ncprop.match(line)] + + # Sort the dimensions (except for the first, which can be unlimited). + # This gives consistent CDL across different platforms. + def sort_key(line): + return ("UNLIMITED" not in line, line) + + dimension_lines = slice(lines.index("dimensions:") + 1, lines.index("variables:")) + lines[dimension_lines] = sorted(lines[dimension_lines], key=sort_key) + cdl = "\n".join(lines) + "\n" + + _check_same(cdl, reference_path, type_comparison_name="CDL") + + +def assertCML(cubes, reference_filename=None, checksum=True): + """Test that the CML for the given cubes matches the contents of + the reference file. + + If the environment variable IRIS_TEST_CREATE_MISSING is + non-empty, the reference file is created if it doesn't exist. + + Parameters + ---------- + cubes : + Either a Cube or a sequence of Cubes. + reference_filename : optional, default=None + The relative path (relative to the test results directory). + If omitted, the result is generated from the calling + method's name, class, and module using + :meth:`iris.tests.IrisTest.result_path`. + checksum : bool, optional + When True, causes the CML to include a checksum for each + Cube's data. Defaults to True. 
+ + """ + if isinstance(cubes, iris.cube.Cube): + cubes = [cubes] + if reference_filename is None: + reference_filename = result_path(None, "cml") + + if isinstance(cubes, (list, tuple)): + xml = iris.cube.CubeList(cubes).xml( + checksum=checksum, order=False, byteorder=False + ) + else: + xml = cubes.xml(checksum=checksum, order=False, byteorder=False) + reference_path = get_result_path(reference_filename) + _check_same(xml, reference_path) + + +def assertTextFile(source_filename, reference_filename, desc="text file"): + """Check if two text files are the same, printing any diffs.""" + with open(source_filename) as source_file: + source_text = source_file.readlines() + with open(reference_filename) as reference_file: + reference_text = reference_file.readlines() + + diff = "".join( + difflib.unified_diff( + reference_text, + source_text, + "Reference", + "Test result", + "", + "", + 0, + ) + ) + fail_string = ( + f"{desc} does not match: reference file " f"{reference_filename} \n {diff}" + ) + assert reference_text == source_text, fail_string + + +def assertDataAlmostEqual(data, reference_filename, **kwargs): + reference_path = get_result_path(reference_filename) + if _check_reference_file(reference_path): + kwargs.setdefault("err_msg", "Reference file %s" % reference_path) + with open(reference_path, "r") as reference_file: + stats = json.load(reference_file) + assert stats.get("shape", []), list(data.shape) + assert stats.get("masked", False), ma.is_masked(data) + nstats = np.array( + ( + stats.get("mean", 0.0), + stats.get("std", 0.0), + stats.get("max", 0.0), + stats.get("min", 0.0), + ), + dtype=np.float64, + ) + if math.isnan(stats.get("mean", 0.0)): + assert math.isnan(data.mean()) + else: + data_stats = np.array( + (data.mean(), data.std(), data.max(), data.min()), + dtype=np.float64, + ) + assertArrayAllClose(nstats, data_stats, **kwargs) + else: + _ensure_folder(reference_path) + stats = collections.OrderedDict( + [ + ("std", np.float64(data.std())), + 
("min", np.float64(data.min())), + ("max", np.float64(data.max())), + ("shape", data.shape), + ("masked", ma.is_masked(data)), + ("mean", np.float64(data.mean())), + ] + ) + with open(reference_path, "w") as reference_file: + reference_file.write(json.dumps(stats)) + + +def assertFilesEqual(test_filename, reference_filename): + reference_path = get_result_path(reference_filename) + if _check_reference_file(reference_path): + fmt = "test file {!r} does not match reference {!r}." + assert filecmp.cmp(test_filename, reference_path) and fmt.format( + test_filename, reference_path + ) + else: + _ensure_folder(reference_path) + shutil.copy(test_filename, reference_path) + + +def assertString(string, reference_filename=None): + """Test that `string` matches the contents of the reference file. + + If the environment variable IRIS_TEST_CREATE_MISSING is + non-empty, the reference file is created if it doesn't exist. + + Parameters + ---------- + string : str + The string to check. + reference_filename : optional, default=None + The relative path (relative to the test results directory). + If omitted, the result is generated from the calling + method's name, class, and module using + :meth:`iris.tests.IrisTest.result_path`. 
+ + """ + if reference_filename is None: + reference_path = result_path(None, "txt") + else: + reference_path = get_result_path(reference_filename) + _check_same(string, reference_path, type_comparison_name="Strings") + + +def assertRepr(obj, reference_filename): + assertString(repr(obj), reference_filename) + + +def _check_same(item, reference_path, type_comparison_name="CML"): + if _check_reference_file(reference_path): + with open(reference_path, "rb") as reference_fh: + reference = "".join( + part.decode("utf-8") for part in reference_fh.readlines() + ) + _assert_str_same(reference, item, reference_path, type_comparison_name) + else: + _ensure_folder(reference_path) + with open(reference_path, "wb") as reference_fh: + reference_fh.writelines(part.encode("utf-8") for part in item) + + +def assertXMLElement(obj, reference_filename): + """Calls the xml_element method given obj and asserts the result is the same as the test file.""" + doc = xml.dom.minidom.Document() + doc.appendChild(obj.xml_element(doc)) + # sort the attributes on xml elements before testing against known good state. + # this is to be compatible with stored test output where xml attrs are stored in alphabetical order, + # (which was default behaviour in python <3.8, but changed to insert order in >3.8) + doc = iris.cube.Cube._sort_xml_attrs(doc) + pretty_xml = doc.toprettyxml(indent=" ") + reference_path = get_result_path(reference_filename) + _check_same(pretty_xml, reference_path, type_comparison_name="XML") + + +def assertArrayEqual(a, b, err_msg=""): + np.testing.assert_array_equal(a, b, err_msg=err_msg) + + +@contextlib.contextmanager +def _recordWarningMatches(expected_regexp=""): + # Record warnings raised matching a given expression. 
+ matches = [] + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + yield matches + messages = [str(warning.message) for warning in w] + expr = re.compile(expected_regexp) + matches.extend(message for message in messages if expr.search(message)) + + +@contextlib.contextmanager +def assertLogs(logger=None, level=None, msg_regex=None): + """An extended version of the usual :meth:`unittest.TestCase.assertLogs`, + which also exercises the logger's message formatting. + + Also adds the ``msg_regex`` kwarg: + If used, check that the result is a single message of the specified + level, and that it matches this regex. + + The inherited version of this method temporarily *replaces* the logger + in order to capture log records generated within the context. + However, in doing so it prevents any messages from being formatted + by the original logger. + This version first calls the original method, but then *also* exercises + the message formatters of all the logger's handlers, just to check that + there are no formatting errors. + + """ + # Invoke the standard assertLogs behaviour. + assertlogging_context = super().assertLogs(logger, level) + with assertlogging_context as watcher: + # Run the caller context, as per original method. + yield watcher + # Check for any formatting errors by running all the formatters. + for record in watcher.records: + for handler in assertlogging_context.logger.handlers: + handler.format(record) + + # Check message, if requested. + if msg_regex: + assert len(watcher.records) == 1 + rec = watcher.records[0] + assert level == rec.levelname + assert rec.msg == msg_regex + + +@contextlib.contextmanager +def assertNoWarningsRegexp(expected_regexp=""): + # Check that no warning matching the given expression is raised. + with _recordWarningMatches(expected_regexp) as matches: + yield + + msg = "Unexpected warning(s) raised, matching '{}' : {!r}." 
+ msg = msg.format(expected_regexp, matches) + assert not (matches, msg) + + +assertMaskedArrayEqual = staticmethod(assert_masked_array_equal) + + +def assertArrayAlmostEqual(a, b, decimal=6): + np.testing.assert_array_almost_equal(a, b, decimal=decimal) + + +assertMaskedArrayAlmostEqual = assert_masked_array_almost_equal + + +def assertArrayAllClose(a, b, rtol=1.0e-7, atol=1.0e-8, **kwargs): + """Check arrays are equal, within given relative + absolute tolerances. + + Parameters + ---------- + a, b : array-like + Two arrays to compare. + rtol, atol : float, optional + Relative and absolute tolerances to apply. + + Other Parameters + ---------------- + Any additional kwargs are passed to numpy.testing.assert_allclose. + + Performs pointwise toleranced comparison, and raises an assertion if + the two are not equal 'near enough'. + For full details see underlying routine numpy.allclose. + + """ + # Handle the 'err_msg' kwarg, which is the only API difference + # between np.allclose and np.testing_assert_allclose. + msg = kwargs.pop("err_msg", None) + ok = np.allclose(a, b, rtol=rtol, atol=atol, **kwargs) + if not ok: + # Calculate errors above a pointwise tolerance : The method is + # taken from "numpy.core.numeric.isclose". + a, b = np.broadcast_arrays(a, b) + errors = np.abs(a - b) - atol + rtol * np.abs(b) + worst_inds = np.unravel_index(np.argmax(errors.flat), errors.shape) + + if msg is None: + # Build a more useful message than np.testing.assert_allclose. 
+ msg = ( + '\nARRAY CHECK FAILED "assertArrayAllClose" :' + "\n with shapes={} {}, atol={}, rtol={}" + "\n worst at element {} : a={} b={}" + "\n absolute error ~{:.3g}, equivalent to rtol ~{:.3e}" + ) + aval, bval = a[worst_inds], b[worst_inds] + absdiff = np.abs(aval - bval) + equiv_rtol = absdiff / bval + msg = msg.format( + a.shape, + b.shape, + atol, + rtol, + worst_inds, + aval, + bval, + absdiff, + equiv_rtol, + ) + + raise AssertionError(msg) + + +@contextlib.contextmanager +def temp_filename(suffix=""): + filename = iris.util.create_temp_filename(suffix) + try: + yield filename + finally: + os.remove(filename) + + +def file_checksum(file_path): + """Generate checksum from file.""" + with open(file_path, "rb") as in_file: + return zlib.crc32(in_file.read()) + + +def _unique_id(): + """Returns the unique ID for the current assertion. + + The ID is composed of two parts: a unique ID for the current test + (which is itself composed of the module, class, and test names), and + a sequential counter (specific to the current test) that is incremented + on each call. + + For example, calls from a "test_tx" routine followed by a "test_ty" + routine might result in:: + test_plot.TestContourf.test_tx.0 + test_plot.TestContourf.test_tx.1 + test_plot.TestContourf.test_tx.2 + test_plot.TestContourf.test_ty.0 + + """ + # Obtain a consistent ID for the current test. + # NB. unittest.TestCase.id() returns different values depending on + # whether the test has been run explicitly, or via test discovery. 
+ # For example: + # python tests/test_plot.py => '__main__.TestContourf.test_tx' + # ird -t => 'iris.tests.test_plot.TestContourf.test_tx' + bits = id().split(".") + if bits[0] == "__main__": + floc = sys.modules["__main__"].__file__ + path, file_name = os.path.split(os.path.abspath(floc)) + bits[0] = os.path.splitext(file_name)[0] + folder, location = os.path.split(path) + bits = [location] + bits + while location not in ["iris", "gallery_tests"]: + folder, location = os.path.split(folder) + bits = [location] + bits + test_id = ".".join(bits) + + # Derive the sequential assertion ID within the test + assertion_id = _assertion_counts[test_id] + _assertion_counts[test_id] += 1 + + return test_id + "." + str(assertion_id) + + +def _check_reference_file(reference_path): + reference_exists = os.path.isfile(reference_path) + if not (reference_exists or os.environ.get("IRIS_TEST_CREATE_MISSING")): + msg = "Missing test result: {}".format(reference_path) + raise AssertionError(msg) + return reference_exists + + +def _ensure_folder(path): + dir_path = os.path.dirname(path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +def check_graphic(): + """Check the hash of the current matplotlib figure matches the expected + image hash for the current graphic test. + + To create missing image test results, set the IRIS_TEST_CREATE_MISSING + environment variable before running the tests. This will result in new + and appropriately ".png" image files being generated in the image + output directory, and the imagerepo.json file being updated. + + """ + graphics.check_graphic( + _unique_id(), + _RESULT_PATH, + ) + + +def _remove_testcase_patches(testcase_patches): + """Helper to remove per-testcase patches installed by :meth:`patch`.""" + # Remove all patches made, ignoring errors. + for p in testcase_patches: + p.stop() + # Reset per-test patch control variable. 
+ testcase_patches.clear() + + +def patch(*args, **kwargs): + """Install a mock.patch, to be removed after the current test. + + The patch is created with mock.patch(*args, **kwargs). + + Returns + ------- + The substitute object returned by patch.start(). + + Examples + -------- + :: + + mock_call = self.patch('module.Class.call', return_value=1) + module_Class_instance.call(3, 4) + self.assertEqual(mock_call.call_args_list, [mock.call(3, 4)]) + + """ + # Make the new patch and start it. + patch = pytest.mock.patch(*args, **kwargs) + start_result = patch.start() + + # Create the per-testcases control variable if it does not exist. + # NOTE: this mimics a setUp method, but continues to work when a + # subclass defines its own setUp. + if not hasattr("testcase_patches"): + testcase_patches = {} + + # When installing the first patch, schedule remove-all at cleanup. + if not testcase_patches: + _remove_testcase_patches(testcase_patches) + + # Record the new patch and start object for reference. + testcase_patches[patch] = start_result + + # Return patch replacement object. + return start_result + + +def assertArrayShapeStats(result, shape, mean, std_dev, rtol=1e-6): + """Assert that the result, a cube, has the provided shape and that the + mean and standard deviation of the data array are also as provided. + Thus build confidence that a cube processing operation, such as a + cube.regrid, has maintained its behaviour. + + """ + assert result.shape == shape + assertArrayAllClose(result.data.mean(), mean, rtol=rtol) + assertArrayAllClose(result.data.std(), std_dev, rtol=rtol) + + +def assertDictEqual(lhs, rhs, msg=None): + """Dictionary Comparison. + + This method overrides unittest.TestCase.assertDictEqual (new in Python3.1) + in order to cope with dictionary comparison where the value of a key may + be a numpy array. + + """ + emsg = f"Provided LHS argument is not a 'Mapping', got {type(lhs)}." 
+ assert isinstance(lhs, Mapping), emsg + + emsg = f"Provided RHS argument is not a 'Mapping', got {type(rhs)}." + assert isinstance(rhs, Mapping), emsg + + emsg = f"{lhs!r} != {rhs!r}." + assert set(lhs.keys()) == set(rhs.keys()), emsg + + for key in lhs.keys(): + lvalue, rvalue = lhs[key], rhs[key] + + if ma.isMaskedArray(lvalue) or ma.isMaskedArray(rvalue): + if not ma.isMaskedArray(lvalue): + emsg = ( + f"Dictionary key {key!r} values are not equal, " + f"the LHS value has type {type(lvalue)} and " + f"the RHS value has type {ma.core.MaskedArray}." + ) + raise AssertionError(emsg) + + if not ma.isMaskedArray(rvalue): + emsg = ( + f"Dictionary key {key!r} values are not equal, " + f"the LHS value has type {ma.core.MaskedArray} and " + f"the RHS value has type {type(lvalue)}." + ) + raise AssertionError(emsg) + + assertMaskedArrayEqual(lvalue, rvalue) + elif isinstance(lvalue, np.ndarray) or isinstance(rvalue, np.ndarray): + if not isinstance(lvalue, np.ndarray): + emsg = ( + f"Dictionary key {key!r} values are not equal, " + f"the LHS value has type {type(lvalue)} and " + f"the RHS value has type {np.ndarray}." + ) + raise AssertionError(emsg) + + if not isinstance(rvalue, np.ndarray): + emsg = ( + f"Dictionary key {key!r} values are not equal, " + f"the LHS value has type {np.ndarray} and " + f"the RHS value has type {type(rvalue)}." + ) + raise AssertionError(emsg) + + assertArrayEqual(lvalue, rvalue) + else: + if lvalue != rvalue: + emsg = ( + f"Dictionary key {key!r} values are not equal, " + f"{lvalue!r} != {rvalue!r}." + ) + raise AssertionError(emsg) + + +def assertEqualAndKind(value, expected): + # Check a value, and also its type 'kind' = float/integer/string. 
+ assert value == expected + assert np.array(value).dtype.kind == np.array(expected).dtype.kind + + +class PPTest: + """A mixin class to provide PP-specific utilities to subclasses of tests.IrisTest.""" + + @contextlib.contextmanager + def cube_save_test( + self, + reference_txt_path, + reference_cubes=None, + reference_pp_path=None, + **kwargs, + ): + """A context manager for testing the saving of Cubes to PP files. + + Args: + + * reference_txt_path: + The path of the file containing the textual PP reference data. + + Kwargs: + + * reference_cubes: + The cube(s) from which the textual PP reference can be re-built if necessary. + * reference_pp_path: + The location of a PP file from which the textual PP reference can be re-built if necessary. + NB. The "reference_cubes" argument takes precedence over this argument. + + The return value from the context manager is the name of a temporary file + into which the PP data to be tested should be saved. + + Example:: + with self.cube_save_test(reference_txt_path, reference_cubes=cubes) as temp_pp_path: + iris.save(cubes, temp_pp_path) + + """ + # Watch out for a missing reference text file + if not os.path.isfile(reference_txt_path): + if reference_cubes: + temp_pp_path = iris.util.create_temp_filename(".pp") + try: + iris.save(reference_cubes, temp_pp_path, **kwargs) + self._create_reference_txt(reference_txt_path, temp_pp_path) + finally: + os.remove(temp_pp_path) + elif reference_pp_path: + self._create_reference_txt(reference_txt_path, reference_pp_path) + else: + raise ValueError( + "Missing all of reference txt file, cubes, and PP path." + ) + + temp_pp_path = iris.util.create_temp_filename(".pp") + try: + # This value is returned to the target of the "with" statement's "as" clause. 
+ yield temp_pp_path + + # Load deferred data for all of the fields (but don't do anything with it) + pp_fields = list(iris.fileformats.pp.load(temp_pp_path)) + for pp_field in pp_fields: + pp_field.data + with open(reference_txt_path, "r") as reference_fh: + reference = "".join(reference_fh) + _assert_str_same( + reference + "\n", + str(pp_fields) + "\n", + reference_txt_path, + type_comparison_name="PP files", + ) + finally: + os.remove(temp_pp_path) + + def _create_reference_txt(self, txt_path, pp_path): + # Load the reference data + pp_fields = list(iris.fileformats.pp.load(pp_path)) + for pp_field in pp_fields: + pp_field.data + + # Clear any header words we don't use + unused = ("lbexp", "lbegin", "lbnrec", "lbproj", "lbtyp") + for pp_field in pp_fields: + for word_name in unused: + setattr(pp_field, word_name, 0) + + # Save the textual representation of the PP fields + with open(txt_path, "w") as txt_file: + txt_file.writelines(str(pp_fields)) + + +skip_plot = graphics.skip_plot + + +def no_warnings(func): + """Provides a decorator to ensure that there are no warnings raised + within the test, otherwise the test will fail. + + """ + + @functools.wraps(func) + def wrapped(self, *args, **kwargs): + with pytest.mock.patch("warnings.warn") as warn: + result = func(self, *args, **kwargs) + assert 0 == warn.call_count, "Got unexpected warnings.\n{}".format( + warn.call_args_list + ) + return result + + return wrapped + + +def env_bin_path(exe_name: AnyStr = None): + """Return a Path object for (an executable in) the environment bin directory. + + Parameters + ---------- + exe_name : str + If set, the name of an executable to append to the path. + + Returns + ------- + exe_path : Path + A path to the bin directory, or an executable file within it. 
+ + Notes + ----- + For use in tests which spawn commands which should call executables within + the Python environment, since many IDEs (Eclipse, PyCharm) don't + automatically include this location in $PATH (as opposed to $PYTHONPATH). + """ + exe_path = Path(os.__file__) + exe_path = (exe_path / "../../../bin").resolve() + if exe_name is not None: + exe_path = exe_path / exe_name + return exe_path From 7942a2b856f8081570670c100637ea84e55f69e7 Mon Sep 17 00:00:00 2001 From: Elias Sadek Date: Mon, 11 Mar 2024 14:57:12 +0000 Subject: [PATCH 2/4] actioned majority of review comments --- .../{iris_test_pytest.py => _shared_utils.py} | 492 +++++++----------- 1 file changed, 194 insertions(+), 298 deletions(-) rename lib/iris/tests/{iris_test_pytest.py => _shared_utils.py} (68%) diff --git a/lib/iris/tests/iris_test_pytest.py b/lib/iris/tests/_shared_utils.py similarity index 68% rename from lib/iris/tests/iris_test_pytest.py rename to lib/iris/tests/_shared_utils.py index 6a1b36b38d..7591b61ac8 100644 --- a/lib/iris/tests/iris_test_pytest.py +++ b/lib/iris/tests/_shared_utils.py @@ -2,14 +2,7 @@ # # This file is part of Iris and is released under the BSD license. # See LICENSE in the root of the repository for full licensing details. -"""Provides testing capabilities and customisations specific to Iris. - -.. note:: This module needs to control the matplotlib backend, so it - **must** be imported before ``matplotlib.pyplot``. - -The primary class for this module is :class:`IrisTest`. 
- -""" +"""Provides testing capabilities and customisations specific to Iris.""" import collections from collections.abc import Mapping @@ -18,8 +11,8 @@ import filecmp import functools import gzip -import inspect -import io + +# import inspect import json import math import os @@ -28,10 +21,7 @@ import re import shutil import subprocess -import sys from typing import AnyStr - -# from unittest import mock import warnings import xml.dom.minidom import zlib @@ -44,6 +34,7 @@ import iris.config import iris.cube import iris.fileformats +import iris.tests import iris.tests.graphics as graphics import iris.util @@ -89,52 +80,6 @@ #: Basepath for test results. _RESULT_PATH = os.path.join(os.path.dirname(__file__), "results") -if "--data-files-used" in sys.argv: - sys.argv.remove("--data-files-used") - fname = "/var/tmp/all_iris_test_resource_paths.txt" - print("saving list of files used by tests to %s" % fname) - _EXPORT_DATAPATHS_FILE = open(fname, "w") -else: - _EXPORT_DATAPATHS_FILE = None - - -if "--create-missing" in sys.argv: - sys.argv.remove("--create-missing") - print("Allowing creation of missing test results.") - os.environ["IRIS_TEST_CREATE_MISSING"] = "true" - - -def main(): - """A wrapper for pytest.main() which adds iris.test specific options to the help (-h) output.""" - if "-h" in sys.argv or "--help" in sys.argv: - stdout = sys.stdout - buff = io.StringIO() - # NB. 
pytest.main() raises an exception after it's shown the help text - try: - sys.stdout = buff - pytest.main() - finally: - sys.stdout = stdout - lines = buff.getvalue().split("\n") - lines.insert(9, "Iris-specific options:") - lines.insert( - 10, - " -d Display matplotlib figures (uses tkagg).", - ) - lines.insert( - 11, - " NOTE: To compare results of failing tests, ", - ) - lines.insert(12, " use idiff.py instead") - lines.insert( - 13, - " --data-files-used Save a list of files used to a temporary file", - ) - lines.insert(14, " -m Create missing test results") - print("\n".join(lines)) - else: - pytest.main() - def _assert_masked_array(assertion, a, b, strict, **kwargs): # Compare masks. @@ -153,7 +98,6 @@ def _assert_masked_array(assertion, a, b, strict, **kwargs): ) -# np should be fine def assert_masked_array_equal(a, b, strict=False): """Check that masked arrays are equal. This requires the unmasked values and masks to be identical. @@ -195,7 +139,6 @@ def assert_masked_array_almost_equal(a, b, decimal=6, strict=False): ) -# seems fine def _assert_str_same( reference_str, test_str, @@ -219,9 +162,6 @@ def _assert_str_same( assert reference_str == test_str, fail_string -_assertion_counts = collections.defaultdict(int) - - def get_data_path(relative_path): """Return the absolute path to a data file when given the relative path as a string, or sequence of strings. @@ -234,8 +174,8 @@ def get_data_path(relative_path): test_data_dir = "" data_path = os.path.join(test_data_dir, relative_path) - if _EXPORT_DATAPATHS_FILE is not None: - _EXPORT_DATAPATHS_FILE.write(data_path + "\n") + if iris.tests._EXPORT_DATAPATHS_FILE is not None: + iris.tests._EXPORT_DATAPATHS_FILE.write(data_path + "\n") if isinstance(data_path, str) and not os.path.exists(data_path): # if the file is gzipped, ungzip it and return the path of the ungzipped @@ -280,33 +220,33 @@ def result_path(self, basename=None, ext=""): Appended file extension. 
""" - if ext and not ext.startswith("."): - ext = "." + ext - - # Generate the folder name from the calling file name. - path = os.path.abspath(inspect.getfile(self.__class__)) - path = os.path.splitext(path)[0] - sub_path = path.rsplit("iris", 1)[1].split("tests", 1)[1][1:] - - # Generate the file name from the calling function name? - if basename is None: - stack = inspect.stack() - for frame in stack[1:]: - if "test_" in frame[3]: - basename = frame[3].replace("test_", "") - break - filename = basename + ext - - result = os.path.join( - self.get_result_path(""), - sub_path.replace("test_", ""), - self.__class__.__name__.replace("Test_", ""), - filename, - ) - return result - - -def assertCMLApproxData(cubes, reference_filename=None, **kwargs): + return None + # if ext and not ext.startswith("."): + # ext = "." + ext + # # Generate the folder name from the calling file name. + # path = os.path.abspath(inspect.getfile(self.__class__)) + # path = os.path.splitext(path)[0] + # sub_path = path.rsplit("iris", 1)[1].split("tests", 1)[1][1:] + # + # # Generate the file name from the calling function name? 
+ # if basename is None: + # stack = inspect.stack() + # for frame in stack[1:]: + # if "test_" in frame[3]: + # basename = frame[3].replace("test_", "") + # break + # filename = basename + ext + # + # result = os.path.join( + # get_result_path(""), + # sub_path.replace("test_", ""), + # self.__class__.__name__.replace("Test_", ""), + # filename, + # ) + # return result + + +def assert_CML_approx_data(cubes, reference_filename=None, **kwargs): # passes args and kwargs on to approx equal if isinstance(cubes, iris.cube.Cube): cubes = [cubes] @@ -319,11 +259,11 @@ def assertCMLApproxData(cubes, reference_filename=None, **kwargs): if fname[-1].endswith(".cml"): fname[-1] = fname[-1][:-4] fname[-1] += ".data.%d.json" % i - assertDataAlmostEqual(cube.data, fname, **kwargs) - assertCML(cubes, reference_filename, checksum=False) + assert_data_almost_equal(cube.data, fname, **kwargs) + assert_CML(cubes, reference_filename, checksum=False) -def assertCDL(netcdf_filename, reference_filename=None, flags="-h"): +def assert_CDL(netcdf_filename, reference_filename=None, flags="-h"): """Test that the CDL for the given netCDF file matches the contents of the reference file. @@ -386,7 +326,7 @@ def sort_key(line): _check_same(cdl, reference_path, type_comparison_name="CDL") -def assertCML(cubes, reference_filename=None, checksum=True): +def assert_CML(cubes, reference_filename=None, checksum=True): """Test that the CML for the given cubes matches the contents of the reference file. 
@@ -422,7 +362,7 @@ def assertCML(cubes, reference_filename=None, checksum=True): _check_same(xml, reference_path) -def assertTextFile(source_filename, reference_filename, desc="text file"): +def assert_text_file(source_filename, reference_filename, desc="text file"): """Check if two text files are the same, printing any diffs.""" with open(source_filename) as source_file: source_text = source_file.readlines() @@ -446,7 +386,7 @@ def assertTextFile(source_filename, reference_filename, desc="text file"): assert reference_text == source_text, fail_string -def assertDataAlmostEqual(data, reference_filename, **kwargs): +def assert_data_almost_equal(data, reference_filename, **kwargs): reference_path = get_result_path(reference_filename) if _check_reference_file(reference_path): kwargs.setdefault("err_msg", "Reference file %s" % reference_path) @@ -470,7 +410,7 @@ def assertDataAlmostEqual(data, reference_filename, **kwargs): (data.mean(), data.std(), data.max(), data.min()), dtype=np.float64, ) - assertArrayAllClose(nstats, data_stats, **kwargs) + assert_array_all_close(nstats, data_stats, **kwargs) else: _ensure_folder(reference_path) stats = collections.OrderedDict( @@ -487,7 +427,7 @@ def assertDataAlmostEqual(data, reference_filename, **kwargs): reference_file.write(json.dumps(stats)) -def assertFilesEqual(test_filename, reference_filename): +def assert_files_equal(test_filename, reference_filename): reference_path = get_result_path(reference_filename) if _check_reference_file(reference_path): fmt = "test file {!r} does not match reference {!r}." @@ -499,7 +439,7 @@ def assertFilesEqual(test_filename, reference_filename): shutil.copy(test_filename, reference_path) -def assertString(string, reference_filename=None): +def assert_string(string, reference_filename=None): """Test that `string` matches the contents of the reference file. 
If the environment variable IRIS_TEST_CREATE_MISSING is @@ -523,8 +463,8 @@ def assertString(string, reference_filename=None): _check_same(string, reference_path, type_comparison_name="Strings") -def assertRepr(obj, reference_filename): - assertString(repr(obj), reference_filename) +def assert_repr(obj, reference_filename): + assert_string(repr(obj), reference_filename) def _check_same(item, reference_path, type_comparison_name="CML"): @@ -540,7 +480,7 @@ def _check_same(item, reference_path, type_comparison_name="CML"): reference_fh.writelines(part.encode("utf-8") for part in item) -def assertXMLElement(obj, reference_filename): +def assert_XML_element(obj, reference_filename): """Calls the xml_element method given obj and asserts the result is the same as the test file.""" doc = xml.dom.minidom.Document() doc.appendChild(obj.xml_element(doc)) @@ -553,7 +493,7 @@ def assertXMLElement(obj, reference_filename): _check_same(pretty_xml, reference_path, type_comparison_name="XML") -def assertArrayEqual(a, b, err_msg=""): +def assert_array_equal(a, b, err_msg=""): np.testing.assert_array_equal(a, b, err_msg=err_msg) @@ -570,11 +510,8 @@ def _recordWarningMatches(expected_regexp=""): @contextlib.contextmanager -def assertLogs(logger=None, level=None, msg_regex=None): - """An extended version of the usual :meth:`unittest.TestCase.assertLogs`, - which also exercises the logger's message formatting. - - Also adds the ``msg_regex`` kwarg: +def assertLogs(caplog, logger=None, level=None, msg_regex=None): + """Also adds the ``msg_regex`` kwarg: If used, check that the result is a single message of the specified level, and that it matches this regex. @@ -587,22 +524,18 @@ def assertLogs(logger=None, level=None, msg_regex=None): there are no formatting errors. """ - # Invoke the standard assertLogs behaviour. - assertlogging_context = super().assertLogs(logger, level) - with assertlogging_context as watcher: - # Run the caller context, as per original method. 
- yield watcher - # Check for any formatting errors by running all the formatters. - for record in watcher.records: - for handler in assertlogging_context.logger.handlers: - handler.format(record) - - # Check message, if requested. - if msg_regex: - assert len(watcher.records) == 1 - rec = watcher.records[0] - assert level == rec.levelname - assert rec.msg == msg_regex + with caplog.at_level(level, logger.name): + # Check for any formatting errors by running all the formatters. + for record in caplog.records: + for handler in caplog.logger.handlers: + handler.format(record) + + # Check message, if requested. + if msg_regex: + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert level == rec.levelname + assert re.match(msg_regex, rec.msg) @contextlib.contextmanager @@ -613,20 +546,14 @@ def assertNoWarningsRegexp(expected_regexp=""): msg = "Unexpected warning(s) raised, matching '{}' : {!r}." msg = msg.format(expected_regexp, matches) - assert not (matches, msg) - + assert not matches, msg -assertMaskedArrayEqual = staticmethod(assert_masked_array_equal) - -def assertArrayAlmostEqual(a, b, decimal=6): +def assert_array_almost_equal(a, b, decimal=6): np.testing.assert_array_almost_equal(a, b, decimal=decimal) -assertMaskedArrayAlmostEqual = assert_masked_array_almost_equal - - -def assertArrayAllClose(a, b, rtol=1.0e-7, atol=1.0e-8, **kwargs): +def assert_array_all_close(a, b, rtol=1.0e-7, atol=1.0e-8, **kwargs): """Check arrays are equal, within given relative + absolute tolerances. Parameters @@ -659,7 +586,7 @@ def assertArrayAllClose(a, b, rtol=1.0e-7, atol=1.0e-8, **kwargs): if msg is None: # Build a more useful message than np.testing.assert_allclose. 
msg = ( - '\nARRAY CHECK FAILED "assertArrayAllClose" :' + '\nARRAY CHECK FAILED "assert_array_all_close" :' "\n with shapes={} {}, atol={}, rtol={}" "\n worst at element {} : a={} b={}" "\n absolute error ~{:.3g}, equivalent to rtol ~{:.3e}" @@ -682,62 +609,12 @@ def assertArrayAllClose(a, b, rtol=1.0e-7, atol=1.0e-8, **kwargs): raise AssertionError(msg) -@contextlib.contextmanager -def temp_filename(suffix=""): - filename = iris.util.create_temp_filename(suffix) - try: - yield filename - finally: - os.remove(filename) - - def file_checksum(file_path): """Generate checksum from file.""" with open(file_path, "rb") as in_file: return zlib.crc32(in_file.read()) -def _unique_id(): - """Returns the unique ID for the current assertion. - - The ID is composed of two parts: a unique ID for the current test - (which is itself composed of the module, class, and test names), and - a sequential counter (specific to the current test) that is incremented - on each call. - - For example, calls from a "test_tx" routine followed by a "test_ty" - routine might result in:: - test_plot.TestContourf.test_tx.0 - test_plot.TestContourf.test_tx.1 - test_plot.TestContourf.test_tx.2 - test_plot.TestContourf.test_ty.0 - - """ - # Obtain a consistent ID for the current test. - # NB. unittest.TestCase.id() returns different values depending on - # whether the test has been run explicitly, or via test discovery. 
-    # For example:
-    #   python tests/test_plot.py => '__main__.TestContourf.test_tx'
-    #   ird -t => 'iris.tests.test_plot.TestContourf.test_tx'
-    bits = id().split(".")
-    if bits[0] == "__main__":
-        floc = sys.modules["__main__"].__file__
-        path, file_name = os.path.split(os.path.abspath(floc))
-        bits[0] = os.path.splitext(file_name)[0]
-        folder, location = os.path.split(path)
-        bits = [location] + bits
-        while location not in ["iris", "gallery_tests"]:
-            folder, location = os.path.split(folder)
-            bits = [location] + bits
-    test_id = ".".join(bits)
-
-    # Derive the sequential assertion ID within the test
-    assertion_id = _assertion_counts[test_id]
-    _assertion_counts[test_id] += 1
-
-    return test_id + "." + str(assertion_id)
-
-
 def _check_reference_file(reference_path):
     reference_exists = os.path.isfile(reference_path)
     if not (reference_exists or os.environ.get("IRIS_TEST_CREATE_MISSING")):
@@ -752,6 +629,7 @@ def _ensure_folder(path):
         os.makedirs(dir_path)
 
 
+# todo: need to find an equivalent for `unique_id` in pytest
 def check_graphic():
     """Check the hash of the current matplotlib figure matches the expected
     image hash for the current graphic test.
@@ -762,21 +640,10 @@
     output directory, and the imagerepo.json file being updated.
 
     """
-    graphics.check_graphic(
-        _unique_id(),
-        _RESULT_PATH,
-    )
-
-
-def _remove_testcase_patches(testcase_patches):
-    """Helper to remove per-testcase patches installed by :meth:`patch`."""
-    # Remove all patches made, ignoring errors.
-    for p in testcase_patches:
-        p.stop()
-    # Reset per-test patch control variable.
-    testcase_patches.clear()
+    assert False
 
 
+# todo: relied on unittest functionality, need to find a pytest alternative
 def patch(*args, **kwargs):
     """Install a mock.patch, to be removed after the current test.
 
@@ -795,28 +662,10 @@
         self.assertEqual(mock_call.call_args_list, [mock.call(3, 4)])
 
     """
-    # Make the new patch and start it.
- patch = pytest.mock.patch(*args, **kwargs) - start_result = patch.start() + raise NotImplementedError() - # Create the per-testcases control variable if it does not exist. - # NOTE: this mimics a setUp method, but continues to work when a - # subclass defines its own setUp. - if not hasattr("testcase_patches"): - testcase_patches = {} - # When installing the first patch, schedule remove-all at cleanup. - if not testcase_patches: - _remove_testcase_patches(testcase_patches) - - # Record the new patch and start object for reference. - testcase_patches[patch] = start_result - - # Return patch replacement object. - return start_result - - -def assertArrayShapeStats(result, shape, mean, std_dev, rtol=1e-6): +def assert_array_shape_stats(result, shape, mean, std_dev, rtol=1e-6): """Assert that the result, a cube, has the provided shape and that the mean and standard deviation of the data array are also as provided. Thus build confidence that a cube processing operation, such as a @@ -824,18 +673,12 @@ def assertArrayShapeStats(result, shape, mean, std_dev, rtol=1e-6): """ assert result.shape == shape - assertArrayAllClose(result.data.mean(), mean, rtol=rtol) - assertArrayAllClose(result.data.std(), std_dev, rtol=rtol) - + assert_array_all_close(result.data.mean(), mean, rtol=rtol) + assert_array_all_close(result.data.std(), std_dev, rtol=rtol) -def assertDictEqual(lhs, rhs, msg=None): - """Dictionary Comparison. - This method overrides unittest.TestCase.assertDictEqual (new in Python3.1) - in order to cope with dictionary comparison where the value of a key may - be a numpy array. - - """ +def assert_dict_equal(lhs, rhs, msg=None): + """Dictionary Comparison.""" emsg = f"Provided LHS argument is not a 'Mapping', got {type(lhs)}." 
assert isinstance(lhs, Mapping), emsg @@ -865,7 +708,7 @@ def assertDictEqual(lhs, rhs, msg=None): ) raise AssertionError(emsg) - assertMaskedArrayEqual(lvalue, rvalue) + assert_masked_array_equal(lvalue, rvalue) elif isinstance(lvalue, np.ndarray) or isinstance(rvalue, np.ndarray): if not isinstance(lvalue, np.ndarray): emsg = ( @@ -883,7 +726,7 @@ def assertDictEqual(lhs, rhs, msg=None): ) raise AssertionError(emsg) - assertArrayEqual(lvalue, rvalue) + assert_array_equal(lvalue, rvalue) else: if lvalue != rvalue: emsg = ( @@ -893,83 +736,44 @@ def assertDictEqual(lhs, rhs, msg=None): raise AssertionError(emsg) -def assertEqualAndKind(value, expected): +def assert_equal_and_kind(value, expected): # Check a value, and also its type 'kind' = float/integer/string. assert value == expected assert np.array(value).dtype.kind == np.array(expected).dtype.kind -class PPTest: - """A mixin class to provide PP-specific utilities to subclasses of tests.IrisTest.""" - - @contextlib.contextmanager - def cube_save_test( - self, - reference_txt_path, - reference_cubes=None, - reference_pp_path=None, - **kwargs, - ): - """A context manager for testing the saving of Cubes to PP files. - - Args: +@contextlib.contextmanager +def cube_save_test( + reference_txt_path, + reference_cubes=None, + reference_pp_path=None, + **kwargs, +): + """A context manager for testing the saving of Cubes to PP files. - * reference_txt_path: - The path of the file containing the textual PP reference data. + Args: - Kwargs: + * reference_txt_path: + The path of the file containing the textual PP reference data. - * reference_cubes: - The cube(s) from which the textual PP reference can be re-built if necessary. - * reference_pp_path: - The location of a PP file from which the textual PP reference can be re-built if necessary. - NB. The "reference_cubes" argument takes precedence over this argument. 
+ Kwargs: - The return value from the context manager is the name of a temporary file - into which the PP data to be tested should be saved. + * reference_cubes: + The cube(s) from which the textual PP reference can be re-built if necessary. + * reference_pp_path: + The location of a PP file from which the textual PP reference can be re-built if necessary. + NB. The "reference_cubes" argument takes precedence over this argument. - Example:: - with self.cube_save_test(reference_txt_path, reference_cubes=cubes) as temp_pp_path: - iris.save(cubes, temp_pp_path) + The return value from the context manager is the name of a temporary file + into which the PP data to be tested should be saved. - """ - # Watch out for a missing reference text file - if not os.path.isfile(reference_txt_path): - if reference_cubes: - temp_pp_path = iris.util.create_temp_filename(".pp") - try: - iris.save(reference_cubes, temp_pp_path, **kwargs) - self._create_reference_txt(reference_txt_path, temp_pp_path) - finally: - os.remove(temp_pp_path) - elif reference_pp_path: - self._create_reference_txt(reference_txt_path, reference_pp_path) - else: - raise ValueError( - "Missing all of reference txt file, cubes, and PP path." - ) + Example:: + with self.cube_save_test(reference_txt_path, reference_cubes=cubes) as temp_pp_path: + iris.save(cubes, temp_pp_path) - temp_pp_path = iris.util.create_temp_filename(".pp") - try: - # This value is returned to the target of the "with" statement's "as" clause. 
- yield temp_pp_path - - # Load deferred data for all of the fields (but don't do anything with it) - pp_fields = list(iris.fileformats.pp.load(temp_pp_path)) - for pp_field in pp_fields: - pp_field.data - with open(reference_txt_path, "r") as reference_fh: - reference = "".join(reference_fh) - _assert_str_same( - reference + "\n", - str(pp_fields) + "\n", - reference_txt_path, - type_comparison_name="PP files", - ) - finally: - os.remove(temp_pp_path) + """ - def _create_reference_txt(self, txt_path, pp_path): + def _create_reference_txt(txt_path, pp_path): # Load the reference data pp_fields = list(iris.fileformats.pp.load(pp_path)) for pp_field in pp_fields: @@ -985,9 +789,101 @@ def _create_reference_txt(self, txt_path, pp_path): with open(txt_path, "w") as txt_file: txt_file.writelines(str(pp_fields)) + # Watch out for a missing reference text file + if not os.path.isfile(reference_txt_path): + if reference_cubes: + temp_pp_path = iris.util.create_temp_filename(".pp") + try: + iris.save(reference_cubes, temp_pp_path, **kwargs) + _create_reference_txt(reference_txt_path, temp_pp_path) + finally: + os.remove(temp_pp_path) + elif reference_pp_path: + _create_reference_txt(reference_txt_path, reference_pp_path) + else: + raise ValueError("Missing all of reference txt file, cubes, and PP path.") + + temp_pp_path = iris.util.create_temp_filename(".pp") + try: + # This value is returned to the target of the "with" statement's "as" clause. 
+ yield temp_pp_path + + # Load deferred data for all of the fields (but don't do anything with it) + pp_fields = list(iris.fileformats.pp.load(temp_pp_path)) + for pp_field in pp_fields: + pp_field.data + with open(reference_txt_path, "r") as reference_fh: + reference = "".join(reference_fh) + _assert_str_same( + reference + "\n", + str(pp_fields) + "\n", + reference_txt_path, + type_comparison_name="PP files", + ) + finally: + os.remove(temp_pp_path) + + +def skip_data(fn): + """Decorator to choose whether to run tests, based on the availability of + external data. + + Example usage: + @skip_data + class MyDataTests(tests.IrisTest): + ... + + """ + no_data = ( + not iris.config.TEST_DATA_DIR + or not os.path.isdir(iris.config.TEST_DATA_DIR) + or os.environ.get("IRIS_TEST_NO_DATA") + ) + + skip = pytest.skipIf(condition=no_data, reason="Test(s) require external data.") + + return skip(fn) + + +def skip_gdal(fn): + """Decorator to choose whether to run tests, based on the availability of the + GDAL library. + + Example usage: + @skip_gdal + class MyGeoTiffTests(test.IrisTest): + ... 
+
+    """
+    skip = pytest.skipIf(condition=not GDAL_AVAILABLE, reason="Test requires 'gdal'.")
+    return skip(fn)
+
 
 skip_plot = graphics.skip_plot
 
+
+skip_sample_data = pytest.skipIf(
+    not SAMPLE_DATA_AVAILABLE,
+    ('Test(s) require "iris-sample-data", ' "which is not available."),
+)
+
+
+skip_nc_time_axis = pytest.skipIf(
+    not NC_TIME_AXIS_AVAILABLE,
+    'Test(s) require "nc_time_axis", which is not available.',
+)
+
+
+skip_inet = pytest.skipIf(
+    not INET_AVAILABLE,
+    ('Test(s) require an "internet connection", ' "which is not available."),
+)
+
+
+skip_stratify = pytest.skipIf(
+    not STRATIFY_AVAILABLE,
+    'Test(s) require "python-stratify", which is not available.',
+)
 
 
 def no_warnings(func):
     """Provides a decorator to ensure that there are no warnings raised
@@ -996,9 +892,9 @@ def no_warnings(func):
     """
 
     @functools.wraps(func)
-    def wrapped(self, *args, **kwargs):
+    def wrapped(*args, **kwargs):
         with pytest.mock.patch("warnings.warn") as warn:
-            result = func(self, *args, **kwargs)
+            result = func(*args, **kwargs)
             assert 0 == warn.call_count, "Got unexpected warnings.\n{}".format(
                 warn.call_args_list
             )

From 686d71e4391064f2eaac41c7a0a8b71ebe4ed688 Mon Sep 17 00:00:00 2001
From: Elias Sadek
Date: Mon, 11 Mar 2024 15:16:33 +0000
Subject: [PATCH 3/4] converted remaining functions to snake_case

---
 lib/iris/tests/_shared_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lib/iris/tests/_shared_utils.py b/lib/iris/tests/_shared_utils.py
index 7591b61ac8..fee382bd7a 100644
--- a/lib/iris/tests/_shared_utils.py
+++ b/lib/iris/tests/_shared_utils.py
@@ -498,7 +498,7 @@ def assert_array_equal(a, b, err_msg=""):
 
 
 @contextlib.contextmanager
-def _recordWarningMatches(expected_regexp=""):
+def _record_warning_matches(expected_regexp=""):
     # Record warnings raised matching a given expression.
matches = [] with warnings.catch_warnings(record=True) as w: @@ -510,7 +510,7 @@ def _recordWarningMatches(expected_regexp=""): @contextlib.contextmanager -def assertLogs(caplog, logger=None, level=None, msg_regex=None): +def assert_logs(caplog, logger=None, level=None, msg_regex=None): """Also adds the ``msg_regex`` kwarg: If used, check that the result is a single message of the specified level, and that it matches this regex. @@ -539,9 +539,9 @@ def assertLogs(caplog, logger=None, level=None, msg_regex=None): @contextlib.contextmanager -def assertNoWarningsRegexp(expected_regexp=""): +def assert_no_warnings_regexp(expected_regexp=""): # Check that no warning matching the given expression is raised. - with _recordWarningMatches(expected_regexp) as matches: + with _record_warning_matches(expected_regexp) as matches: yield msg = "Unexpected warning(s) raised, matching '{}' : {!r}." From 6d16a15d5adf1352e28c610a66fdba869f69f1ae Mon Sep 17 00:00:00 2001 From: Elias Sadek Date: Mon, 11 Mar 2024 15:30:33 +0000 Subject: [PATCH 4/4] actioned majority of review comments --- lib/iris/tests/_shared_utils.py | 44 +++++++++------------------------ 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/lib/iris/tests/_shared_utils.py b/lib/iris/tests/_shared_utils.py index fee382bd7a..af2a1b03ab 100644 --- a/lib/iris/tests/_shared_utils.py +++ b/lib/iris/tests/_shared_utils.py @@ -221,29 +221,7 @@ def result_path(self, basename=None, ext=""): """ return None - # if ext and not ext.startswith("."): - # ext = "." + ext - # # Generate the folder name from the calling file name. - # path = os.path.abspath(inspect.getfile(self.__class__)) - # path = os.path.splitext(path)[0] - # sub_path = path.rsplit("iris", 1)[1].split("tests", 1)[1][1:] - # - # # Generate the file name from the calling function name? 
- # if basename is None: - # stack = inspect.stack() - # for frame in stack[1:]: - # if "test_" in frame[3]: - # basename = frame[3].replace("test_", "") - # break - # filename = basename + ext - # - # result = os.path.join( - # get_result_path(""), - # sub_path.replace("test_", ""), - # self.__class__.__name__.replace("Test_", ""), - # filename, - # ) - # return result + # todo: complete this! def assert_CML_approx_data(cubes, reference_filename=None, **kwargs): @@ -511,20 +489,16 @@ def _record_warning_matches(expected_regexp=""): @contextlib.contextmanager def assert_logs(caplog, logger=None, level=None, msg_regex=None): - """Also adds the ``msg_regex`` kwarg: - If used, check that the result is a single message of the specified + """If msg_regex is used, checks that the result is a single message of the specified level, and that it matches this regex. - The inherited version of this method temporarily *replaces* the logger - in order to capture log records generated within the context. - However, in doing so it prevents any messages from being formatted - by the original logger. - This version first calls the original method, but then *also* exercises - the message formatters of all the logger's handlers, just to check that - there are no formatting errors. + Checks that there is at least one message logged at the given parameters, + but then *also* exercises the message formatters of all the logger's handlers, + just to check that there are no formatting errors. """ with caplog.at_level(level, logger.name): + assert len(caplog.records) != 0 # Check for any formatting errors by running all the formatters. for record in caplog.records: for handler in caplog.logger.handlers: @@ -678,7 +652,11 @@ def assert_array_shape_stats(result, shape, mean, std_dev, rtol=1e-6): def assert_dict_equal(lhs, rhs, msg=None): - """Dictionary Comparison.""" + """Dictionary Comparison. 
+ + This allows us to cope with dictionary comparison where the value of a key + may be a numpy array. + """ emsg = f"Provided LHS argument is not a 'Mapping', got {type(lhs)}." assert isinstance(lhs, Mapping), emsg