Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions iris_grib/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
import os
import os.path

import numpy as np

try:
from iris.tests import IrisTest_nometa as IrisTest
except ImportError:
from iris.tests import IrisTest

from iris.tests import main, skip_data, get_data_path

from iris_grib.message import GribMessage


#: Basepath for iris-grib test results.
_RESULT_PATH = os.path.join(os.path.dirname(__file__), 'results')
Expand Down Expand Up @@ -92,3 +96,113 @@ def get_testdata_path(relative_path):
if not isinstance(relative_path, str):
relative_path = os.path.join(*relative_path)
return os.path.abspath(os.path.join(_TESTDATA_PATH, relative_path))


class TestGribMessage(IrisGribTest):
def assertGribMessageContents(self, filename, contents):
"""
Evaluate whether all messages in a GRIB2 file contain the provided
contents.

* filename (string)
The path on disk of an existing GRIB file

* contents
An iterable of GRIB message keys and expected values.

"""
messages = GribMessage.messages_from_filename(filename)
for message in messages:
for element in contents:
section, key, val = element
self.assertEqual(message.sections[section][key], val)

def assertGribMessageDifference(
self, filename1, filename2, diffs, skip_keys=(), skip_sections=()
):
"""
Evaluate that the two messages only differ in the ways specified.

* filename[0|1] (string)
The path on disk of existing GRIB files

* diffs
An dictionary of GRIB message keys and expected diff values:
{key: (m1val, m2val),...} .

* skip_keys
An iterable of key names to ignore during comparison.

* skip_sections
An iterable of section numbers to ignore during comparison.

"""
messages1 = list(GribMessage.messages_from_filename(filename1))
messages2 = list(GribMessage.messages_from_filename(filename2))
self.assertEqual(len(messages1), len(messages2))
for m1, m2 in zip(messages1, messages2):
m1_sect = set(m1.sections.keys())
m2_sect = set(m2.sections.keys())

for missing_section in m1_sect ^ m2_sect:
what = (
"introduced" if missing_section in m1_sect else "removed"
)
# Assert that an introduced section is in the diffs.
self.assertIn(
missing_section,
skip_sections,
msg="Section {} {}".format(missing_section, what),
)

for section in m1_sect & m2_sect:
# For each section, check that the differences are
# known diffs.
m1_keys = set(m1.sections[section]._keys)
m2_keys = set(m2.sections[section]._keys)

difference = m1_keys ^ m2_keys
unexpected_differences = difference - set(skip_keys)
if unexpected_differences:
self.fail(
"There were keys in section {} which \n"
"weren't in both messages and which weren't "
"skipped.\n{}"
"".format(section, ", ".join(unexpected_differences))
)

keys_to_compare = m1_keys & m2_keys - set(skip_keys)

for key in keys_to_compare:
m1_value = m1.sections[section][key]
m2_value = m2.sections[section][key]
msg = "{} {} != {}"
if key not in diffs:
# We have a key which we expect to be the same for
# both messages.
if isinstance(m1_value, np.ndarray):
# A large tolerance appears to be required for
# gribapi 1.12, but not for 1.14.
self.assertArrayAlmostEqual(
m1_value, m2_value, decimal=2
)
else:
self.assertEqual(
m1_value,
m2_value,
msg=msg.format(key, m1_value, m2_value),
)
else:
# We have a key which we expect to be different
# for each message.
self.assertEqual(
m1_value,
diffs[key][0],
msg=msg.format(key, m1_value, diffs[key][0]),
)

self.assertEqual(
m2_value,
diffs[key][1],
msg=msg.format(key, m2_value, diffs[key][1]),
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
from iris.fileformats.pp import EARTH_RADIUS as UM_DEFAULT_EARTH_RADIUS
from iris.util import is_regular

from iris.tests import TestGribMessage

from iris_grib.grib_phenom_translation import GRIBCode


class TestGDT5(TestGribMessage):
class TestGDT5(tests.TestGribMessage):
def test_save_load(self):
# Load sample UKV data (variable-resolution rotated grid).
path = tests.get_data_path(("PP", "ukV1", "ukVpmslont.pp"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
import iris.coords
import iris.coord_systems

from iris.tests import TestGribMessage
import iris.tests.stock as stock

from iris_grib.grib_phenom_translation import GRIBCode


class TestPDT11(TestGribMessage):
class TestPDT11(tests.TestGribMessage):
def test_perturbation(self):
path = tests.get_data_path(
("NetCDF", "global", "xyt", "SMALL_hires_wind_u_for_ipcc4.nc")
Expand Down
4 changes: 1 addition & 3 deletions iris_grib/tests/integration/save_rules/test_grib_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@
import iris.exceptions
import iris.util

from iris.tests import TestGribMessage

import gribapi
from iris_grib._load_convert import _MDI as MDI


class TestLoadSave(TestGribMessage):
class TestLoadSave(tests.TestGribMessage):
def setUp(self):
self.skip_keys = []

Expand Down