diff --git a/lib/iris/tests/__init__.py b/lib/iris/tests/__init__.py index bf2d9fef23..64d9f0ef46 100644 --- a/lib/iris/tests/__init__.py +++ b/lib/iris/tests/__init__.py @@ -539,26 +539,37 @@ def assertRaisesRegexp(self, *args, **kwargs): *args, **kwargs) @contextlib.contextmanager - def assertGivesWarning(self, expected_regexp='', expect_warning=True): - # Check that a warning is raised matching a given expression, or that - # no warning matching the given expression is raised. + def _recordWarningMatches(self, expected_regexp=''): + # Record warnings raised matching a given expression. + matches = [] with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - yield + yield matches messages = [str(warning.message) for warning in w] expr = re.compile(expected_regexp) - matches = [message for message in messages if expr.search(message)] - warning_raised = any(matches) - if expect_warning: - if not warning_raised: - msg = "Warning matching '{}' not raised." - msg = msg.format(expected_regexp) - self.assertEqual(expect_warning, warning_raised, msg) - else: - if warning_raised: - msg = "Unexpected warning(s) raised, matching '{}' : {!r}." - msg = msg.format(expected_regexp, matches) - self.assertEqual(expect_warning, warning_raised, msg) + matches.extend(message for message in messages + if expr.search(message)) + + @contextlib.contextmanager + def assertWarnsRegexp(self, expected_regexp=''): + # Check that a warning is raised matching a given expression. + with self._recordWarningMatches(expected_regexp) as matches: + yield + + msg = "Warning matching '{}' not raised." + msg = msg.format(expected_regexp) + self.assertTrue(matches, msg) + + + @contextlib.contextmanager + def assertNoWarningsRegexp(self, expected_regexp=''): + # Check that no warning matching the given expression is raised. + with self._recordWarningMatches(expected_regexp) as matches: + yield + + msg = "Unexpected warning(s) raised, matching '{}' : {!r}." + msg = msg.format(expected_regexp, matches) + self.assertFalse(matches, msg) def _assertMaskedArray(self, assertion, a, b, strict, **kwargs): # Define helper function to extract unmasked values as a 1d diff --git a/lib/iris/tests/unit/fileformats/netcdf/test_Saver.py b/lib/iris/tests/unit/fileformats/netcdf/test_Saver.py index fcf70c1b3c..449fac03d1 100644 --- a/lib/iris/tests/unit/fileformats/netcdf/test_Saver.py +++ b/lib/iris/tests/unit/fileformats/netcdf/test_Saver.py @@ -410,7 +410,7 @@ def test_contains_fill_value_passed(self): # Test that a warning is raised if the data contains the fill value. cube = self._make_cube('>f4') fill_value = 1 - with self.assertGivesWarning( + with self.assertWarnsRegexp( 'contains unmasked data points equal to the fill-value'): with self._netCDF_var(cube, fill_value=fill_value): pass @@ -420,7 +420,7 @@ def test_contains_fill_value_byte(self): # when it is of a byte type. cube = self._make_cube('>i1') fill_value = 1 - with self.assertGivesWarning( + with self.assertWarnsRegexp( 'contains unmasked data points equal to the fill-value'): with self._netCDF_var(cube, fill_value=fill_value): pass @@ -430,7 +430,7 @@ def test_contains_default_fill_value(self): # value if no fill_value argument is supplied. cube = self._make_cube('>f4') cube.data[0, 0] = nc.default_fillvals['f4'] - with self.assertGivesWarning( + with self.assertWarnsRegexp( 'contains unmasked data points equal to the fill-value'): with self._netCDF_var(cube): pass @@ -440,7 +440,7 @@ def test_contains_default_fill_value_byte(self): # value if no fill_value argument is supplied when the data is of a # byte type. cube = self._make_cube('>i1') - with self.assertGivesWarning(r'\(fill\|mask\)', expect_warning=False): + with self.assertNoWarningsRegexp(r'\(fill\|mask\)'): with self._netCDF_var(cube): pass @@ -449,7 +449,7 @@ def test_contains_masked_fill_value(self): # a masked point. fill_value = 1 cube = self._make_cube('>f4', masked_value=fill_value) - with self.assertGivesWarning(r'\(fill\|mask\)', expect_warning=False): + with self.assertNoWarningsRegexp(r'\(fill\|mask\)'): with self._netCDF_var(cube, fill_value=fill_value): pass @@ -457,7 +457,7 @@ def test_masked_byte_default_fill_value(self): # Test that a warning is raised when saving masked byte data with no # fill value supplied. cube = self._make_cube('>i1', masked_value=1) - with self.assertGivesWarning(r'\(fill\|mask\)', expect_warning=False): + with self.assertNoWarningsRegexp(r'\(fill\|mask\)'): with self._netCDF_var(cube): pass @@ -466,7 +466,7 @@ def test_masked_byte_fill_value_passed(self): # fill value supplied if the the data does not contain the fill_value. fill_value = 100 cube = self._make_cube('>i1', masked_value=2) - with self.assertGivesWarning(r'\(fill\|mask\)', expect_warning=False): + with self.assertNoWarningsRegexp(r'\(fill\|mask\)'): with self._netCDF_var(cube, fill_value=fill_value): pass diff --git a/lib/iris/tests/unit/fileformats/pp/test_PPField.py b/lib/iris/tests/unit/fileformats/pp/test_PPField.py index 774f7feda0..1a402a2631 100644 --- a/lib/iris/tests/unit/fileformats/pp/test_PPField.py +++ b/lib/iris/tests/unit/fileformats/pp/test_PPField.py @@ -108,7 +108,7 @@ def field_checksum(data): data_64 = np.linspace(0, 1, num=10, endpoint=False).reshape(2, 5) checksum_32 = field_checksum(data_64.astype('>f4')) msg = 'Downcasting array precision from float64 to float32 for save.' - with self.assertGivesWarning(msg): + with self.assertWarnsRegexp(msg): checksum_64 = field_checksum(data_64.astype('>f8')) self.assertEqual(checksum_32, checksum_64) @@ -119,7 +119,7 @@ def test_masked_mdi_value_warning(self): # Make float32 data, as float64 default produces an extra warning. field.data = np.ma.masked_array([1., field.bmdi, 3.], dtype=np.float32) msg = 'PPField data contains unmasked points' - with self.assertGivesWarning(msg): + with self.assertWarnsRegexp(msg): with self.temp_filename('.pp') as temp_filename: with open(temp_filename, 'wb') as pp_file: field.save(pp_file) @@ -131,7 +131,7 @@ def test_unmasked_mdi_value_warning(self): # Make float32 data, as float64 default produces an extra warning. field.data = np.array([1., field.bmdi, 3.], dtype=np.float32) msg = 'PPField data contains unmasked points' - with self.assertGivesWarning(msg): + with self.assertWarnsRegexp(msg): with self.temp_filename('.pp') as temp_filename: with open(temp_filename, 'wb') as pp_file: field.save(pp_file) @@ -146,7 +146,7 @@ def test_mdi_masked_value_nowarning(self): # Set underlying data value at masked point to BMDI value. field.data.data[1] = field.bmdi self.assertArrayAllClose(field.data.data[1], field.bmdi) - with self.assertGivesWarning(r'\(mask\|fill\)', expect_warning=False): + with self.assertNoWarningsRegexp(r'\(mask\|fill\)'): with self.temp_filename('.pp') as temp_filename: with open(temp_filename, 'wb') as pp_file: field.save(pp_file) diff --git a/lib/iris/tests/unit/test_Future.py b/lib/iris/tests/unit/test_Future.py index 0cbdf90a76..321b85ebd2 100644 --- a/lib/iris/tests/unit/test_Future.py +++ b/lib/iris/tests/unit/test_Future.py @@ -40,7 +40,7 @@ def test_valid_clip_latitudes(self): future = Future() new_value = not future.clip_latitudes msg = "'Future' property 'clip_latitudes' is deprecated" - with self.assertGivesWarning(msg): + with self.assertWarnsRegexp(msg): future.clip_latitudes = new_value self.assertEqual(future.clip_latitudes, new_value)