diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 2a4c67036d6..4f0232236f8 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -7,8 +7,8 @@ from ..core.formatting import format_item from .utils import ( - _determine_cmap_params, _infer_xy_labels, import_matplotlib_pyplot, - label_from_attrs) + _infer_xy_labels, _process_cmap_cbar_kwargs, + import_matplotlib_pyplot, label_from_attrs) # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams @@ -219,32 +219,13 @@ def map_dataarray(self, func, x, y, **kwargs): """ - cmapkw = kwargs.get('cmap') - colorskw = kwargs.get('colors') - cbar_kwargs = kwargs.pop('cbar_kwargs', {}) - cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) - if kwargs.get('cbar_ax', None) is not None: raise ValueError('cbar_ax not supported by FacetGrid.') - # colors is mutually exclusive with cmap - if cmapkw and colorskw: - raise ValueError("Can't specify both cmap and colors.") - - # These should be consistent with xarray.plot._plot2d - cmap_kwargs = {'plot_data': self.data.values, - # MPL default - 'levels': 7 if 'contour' in func.__name__ else None, - 'filled': func.__name__ != 'contour', - } - - cmap_args = getfullargspec(_determine_cmap_params).args - cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, kwargs, self.data.values) - cmap_params = _determine_cmap_params(**cmap_kwargs) - - if colorskw is not None: - cmap_params['cmap'] = None + self._cmap_extend = cmap_params.get('extend') # Order is important func_kwargs = kwargs.copy() @@ -260,7 +241,7 @@ def map_dataarray(self, func, x, y, **kwargs): # None is the sentinel value if d is not None: subset = self.data.loc[d] - mappable = func(subset, x, y, ax=ax, **func_kwargs) + mappable = func(subset, x=x, y=y, ax=ax, **func_kwargs) self._mappables.append(mappable) self._cmap_extend = cmap_params.get('extend') @@ -271,36 +252,24 @@ def map_dataarray(self, func, x, y, **kwargs): return self - def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): - """ - Apply a line plot to a 2d facet subset of the data. - - Parameters - ---------- - x, y, hue: string - dimension names for the axes and hues of each facet - - Returns - ------- - self : FacetGrid object - - """ - from .plot import line, _infer_line_data + def map_dataarray_line(self, func, x, y, **kwargs): + from .plot import _infer_line_data add_legend = kwargs.pop('add_legend', True) kwargs['add_legend'] = False + func_kwargs = kwargs.copy() + func_kwargs['_labels'] = False for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: subset = self.data.loc[d] - mappable = line(subset, x=x, y=y, hue=hue, - ax=ax, _labels=False, - **kwargs) + mappable = func(subset, x=x, y=y, ax=ax, **func_kwargs) self._mappables.append(mappable) + _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( darray=self.data.loc[self.name_dicts.flat[0]], - x=x, y=y, hue=hue) + x=x, y=y, hue=func_kwargs['hue']) self._hue_var = hueplt self._hue_label = huelabel @@ -520,3 +489,33 @@ def map(self, func, *args, **kwargs): self._finalize_grid(*args[:2]) return self + + +def _easy_facetgrid(data, plotfunc, kind, x=None, y=None, row=None, + col=None, col_wrap=None, sharex=True, sharey=True, + aspect=None, size=None, subplot_kws=None, **kwargs): + """ + Convenience method to call xarray.plot.FacetGrid from 2d plotting methods + + kwargs are the arguments to 2d plotting method + """ + ax = kwargs.pop('ax', None) + figsize = kwargs.pop('figsize', None) + if ax is not None: + raise ValueError("Can't use axes when making faceted plots.") + if aspect is None: + aspect = 1 + if size is None: + size = 3 + elif figsize is not None: + raise ValueError('cannot provide both `figsize` and `size` arguments') + + g = FacetGrid(data=data, col=col, row=row, col_wrap=col_wrap, + sharex=sharex, sharey=sharey, figsize=figsize, + aspect=aspect, size=size, subplot_kws=subplot_kws) + + if kind == 'line': + return g.map_dataarray_line(plotfunc, x, y, **kwargs) + + if kind == 'dataarray': + return g.map_dataarray(plotfunc, x, y, **kwargs) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9178dd8f031..5b60f8d73a1 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -6,97 +6,96 @@ DataArray.plot._____ """ import functools -import warnings -from datetime import datetime import numpy as np import pandas as pd from xarray.core.common import contains_cftime_datetimes -from .facetgrid import FacetGrid +from .facetgrid import _easy_facetgrid from .utils import ( - ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, + _add_colorbar, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, _interval_to_double_bound_points, _interval_to_mid_points, - _resolve_intervals_2dplot, _valid_other_type, get_axis, - import_matplotlib_pyplot, label_from_attrs) + _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_2dplot, + _update_axes, _valid_other_type, get_axis, import_matplotlib_pyplot, + label_from_attrs) -def _valid_numpy_subdtype(x, numpy_types): - """ - Is any dtype from numpy_types superior to the dtype of x? - """ - # If any of the types given in numpy_types is understood as numpy.generic, - # all possible x will be considered valid. This is probably unwanted. - for t in numpy_types: - assert not np.issubdtype(np.generic, t) +def _infer_line_data(darray, x, y, hue): + error_msg = ('must be either None or one of ({0:s})' + .format(', '.join([repr(dd) for dd in darray.dims]))) + ndims = len(darray.dims) - return any(np.issubdtype(x.dtype, t) for t in numpy_types) + if x is not None and x not in darray.dims and x not in darray.coords: + raise ValueError('x ' + error_msg) + if y is not None and y not in darray.dims and y not in darray.coords: + raise ValueError('y ' + error_msg) -def _ensure_plottable(*args): - """ - Raise exception if there is anything in args that can't be plotted on an - axis by matplotlib. - """ - numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] - other_types = [datetime] + if x is not None and y is not None: + raise ValueError('You cannot specify both x and y kwargs' + 'for line plots.') - for x in args: - if not (_valid_numpy_subdtype(np.array(x), numpy_types) - or _valid_other_type(np.array(x), other_types)): - raise TypeError('Plotting requires coordinates to be numeric ' - 'or dates of type np.datetime64 or ' - 'datetime.datetime or pd.Interval.') + if ndims == 1: + huename = None + hueplt = None + huelabel = '' + if x is not None: + xplt = darray[x] + yplt = darray -def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, - col_wrap=None, sharex=True, sharey=True, aspect=None, - size=None, subplot_kws=None, **kwargs): - """ - Convenience method to call xarray.plot.FacetGrid from 2d plotting methods + elif y is not None: + xplt = darray + yplt = darray[y] - kwargs are the arguments to 2d plotting method - """ - ax = kwargs.pop('ax', None) - figsize = kwargs.pop('figsize', None) - if ax is not None: - raise ValueError("Can't use axes when making faceted plots.") - if aspect is None: - aspect = 1 - if size is None: - size = 3 - elif figsize is not None: - raise ValueError('cannot provide both `figsize` and `size` arguments') - - g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap, - sharex=sharex, sharey=sharey, figsize=figsize, - aspect=aspect, size=size, subplot_kws=subplot_kws) - return g.map_dataarray(plotfunc, x, y, **kwargs) - - -def _line_facetgrid(darray, row=None, col=None, hue=None, - col_wrap=None, sharex=True, sharey=True, aspect=None, - size=None, subplot_kws=None, **kwargs): - """ - Convenience method to call xarray.plot.FacetGrid for line plots - kwargs are the arguments to pyplot.plot() - """ - ax = kwargs.pop('ax', None) - figsize = kwargs.pop('figsize', None) - if ax is not None: - raise ValueError("Can't use axes when making faceted plots.") - if aspect is None: - aspect = 1 - if size is None: - size = 3 - elif figsize is not None: - raise ValueError('cannot provide both `figsize` and `size` arguments') + else: # Both x & y are None + dim = darray.dims[0] + xplt = darray[dim] + yplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError('For 2D inputs, please' + 'specify either hue, x or y.') + + if y is None: + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) + xplt = darray[xname] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename) + xplt = xplt.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + yplt = darray.transpose(xname, huename) + + else: + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + xplt = darray.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) - g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap, - sharex=sharex, sharey=sharey, figsize=figsize, - aspect=aspect, size=size, subplot_kws=subplot_kws) - return g.map_dataarray_line(hue=hue, **kwargs) + else: + xplt = darray.transpose(yname, huename) + + huelabel = label_from_attrs(darray[huename]) + hueplt = darray[huename] + + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + return xplt, yplt, hueplt, xlabel, ylabel, huelabel def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, @@ -184,83 +183,6 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, return plotfunc(darray, **kwargs) -def _infer_line_data(darray, x, y, hue): - error_msg = ('must be either None or one of ({0:s})' - .format(', '.join([repr(dd) for dd in darray.dims]))) - ndims = len(darray.dims) - - if x is not None and x not in darray.dims and x not in darray.coords: - raise ValueError('x ' + error_msg) - - if y is not None and y not in darray.dims and y not in darray.coords: - raise ValueError('y ' + error_msg) - - if x is not None and y is not None: - raise ValueError('You cannot specify both x and y kwargs' - 'for line plots.') - - if ndims == 1: - huename = None - hueplt = None - huelabel = '' - - if x is not None: - xplt = darray[x] - yplt = darray - - elif y is not None: - xplt = darray - yplt = darray[y] - - else: # Both x & y are None - dim = darray.dims[0] - xplt = darray[dim] - yplt = darray - - else: - if x is None and y is None and hue is None: - raise ValueError('For 2D inputs, please' - 'specify either hue, x or y.') - - if y is None: - xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename) - xplt = xplt.transpose(otherdim, huename) - else: - raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) - - else: - yplt = darray.transpose(xname, huename) - - else: - yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - xplt = darray.transpose(otherdim, huename) - else: - raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) - - else: - xplt = darray.transpose(yname, huename) - - huelabel = label_from_attrs(darray[huename]) - hueplt = darray[huename] - - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) - - return xplt, yplt, hueplt, xlabel, ylabel, huelabel - - # This function signature should not change so that it can use # matplotlib format strings def line(darray, *args, **kwargs): @@ -316,7 +238,8 @@ def line(darray, *args, **kwargs): if row or col: allargs = locals().copy() allargs.update(allargs.pop('kwargs')) - return _line_facetgrid(**allargs) + allargs.pop('darray') + return _easy_facetgrid(darray, line, kind='line', **allargs) ndims = len(darray.dims) if ndims > 2: @@ -496,48 +419,6 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): return primitive -def _update_axes(ax, xincrease, yincrease, - xscale=None, yscale=None, - xticks=None, yticks=None, - xlim=None, ylim=None): - """ - Update axes with provided parameters - """ - if xincrease is None: - pass - elif xincrease and ax.xaxis_inverted(): - ax.invert_xaxis() - elif not xincrease and not ax.xaxis_inverted(): - ax.invert_xaxis() - - if yincrease is None: - pass - elif yincrease and ax.yaxis_inverted(): - ax.invert_yaxis() - elif not yincrease and not ax.yaxis_inverted(): - ax.invert_yaxis() - - # The default xscale, yscale needs to be None. - # If we set a scale it resets the axes formatters, - # This means that set_xscale('linear') on a datetime axis - # will remove the date labels. So only set the scale when explicitly - # asked to. https://github.com/matplotlib/matplotlib/issues/8740 - if xscale is not None: - ax.set_xscale(xscale) - if yscale is not None: - ax.set_yscale(yscale) - - if xticks is not None: - ax.set_xticks(xticks) - if yticks is not None: - ax.set_yticks(yticks) - - if xlim is not None: - ax.set_xlim(xlim) - if ylim is not None: - ax.set_ylim(ylim) - - # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. class _PlotMethods(object): @@ -565,44 +446,6 @@ def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) -def _rescale_imshow_rgb(darray, vmin, vmax, robust): - assert robust or vmin is not None or vmax is not None - # TODO: remove when min numpy version is bumped to 1.13 - # There's a cyclic dependency via DataArray, so we can't import from - # xarray.ufuncs in global scope. - from xarray.ufuncs import maximum, minimum - - # Calculate vmin and vmax automatically for `robust=True` - if robust: - if vmax is None: - vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE) - if vmin is None: - vmin = np.nanpercentile(darray, ROBUST_PERCENTILE) - # If not robust and one bound is None, calculate the default other bound - # and check that an interval between them exists. - elif vmax is None: - vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 - if vmax < vmin: - raise ValueError( - 'vmin=%r is less than the default vmax (%r) - you must supply ' - 'a vmax > vmin in this case.' % (vmin, vmax)) - elif vmin is None: - vmin = 0 - if vmin > vmax: - raise ValueError( - 'vmax=%r is less than the default vmin (0) - you must supply ' - 'a vmin < vmax in this case.' % vmax) - # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float - # to avoid precision loss, integer over/underflow, etc with extreme inputs. - # After scaling, downcast to 32-bit float. This substantially reduces - # memory usage after we hand `darray` off to matplotlib. - darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'xarray.ufuncs', - PendingDeprecationWarning) - return minimum(maximum(darray, 0), 1) - - def _plot2d(plotfunc): """ Decorator for common 2d plotting logic @@ -745,38 +588,23 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, allargs = locals().copy() allargs.pop('imshow_rgb') allargs.update(allargs.pop('kwargs')) - + allargs.pop('darray') # Need the decorated plotting function allargs['plotfunc'] = globals()[plotfunc.__name__] - - return _easy_facetgrid(**allargs) + return _easy_facetgrid(darray, kind='dataarray', **allargs) plt = import_matplotlib_pyplot() - # colors is mutually exclusive with cmap - if cmap and colors: - raise ValueError("Can't specify both cmap and colors.") - # colors is only valid when levels is supplied or the plot is of type - # contour or contourf - if colors and (('contour' not in plotfunc.__name__) and (not levels)): - raise ValueError("Can only specify colors with contour or levels") - # we should not be getting a list of colors in cmap anymore - # is there a better way to do this test? - if isinstance(cmap, (list, tuple)): - warnings.warn("Specifying a list of colors in cmap is deprecated. " - "Use colors keyword instead.", - DeprecationWarning, stacklevel=3) - rgb = kwargs.pop('rgb', None) - xlab, ylab = _infer_xy_labels( - darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb) - if rgb is not None and plotfunc.__name__ != 'imshow': raise ValueError('The "rgb" keyword is only valid for imshow()') elif rgb is not None and not imshow_rgb: raise ValueError('The "rgb" keyword is only valid for imshow()' 'with a three-dimensional array (per facet)') + xlab, ylab = _infer_xy_labels( + darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb) + # better to pass the ndarrays directly to plotting functions xval = darray[xlab].values yval = darray[ylab].values @@ -810,22 +638,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, _ensure_plottable(xplt, yplt) - if 'contour' in plotfunc.__name__ and levels is None: - levels = 7 # this is the matplotlib default - - cmap_kwargs = {'plot_data': zval.data, - 'vmin': vmin, - 'vmax': vmax, - 'cmap': colors if colors else cmap, - 'center': center, - 'robust': robust, - 'extend': extend, - 'levels': levels, - 'filled': plotfunc.__name__ != 'contour', - 'norm': norm, - } - - cmap_params = _determine_cmap_params(**cmap_kwargs) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, locals(), zval.data) if 'contour' in plotfunc.__name__: # extend is a keyword argument only for contour and contourf, but @@ -861,16 +675,12 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, ax.set_title(darray._title_for_slice()) if add_colorbar: - cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) - cbar_kwargs.setdefault('extend', cmap_params['extend']) - if cbar_ax is None: - cbar_kwargs.setdefault('ax', ax) - else: - cbar_kwargs.setdefault('cax', cbar_ax) - cbar = plt.colorbar(primitive, **cbar_kwargs) if add_labels and 'label' not in cbar_kwargs: - cbar.set_label(label_from_attrs(darray)) - elif cbar_ax is not None or cbar_kwargs is not None: + cbar_kwargs['label'] = label_from_attrs(darray) + cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, + cmap_params) + + elif (cbar_ax is not None or cbar_kwargs): # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False.") @@ -1020,54 +830,6 @@ def contourf(x, y, z, ax, **kwargs): return primitive -def _is_monotonic(coord, axis=0): - """ - >>> _is_monotonic(np.array([0, 1, 2])) - True - >>> _is_monotonic(np.array([2, 1, 0])) - True - >>> _is_monotonic(np.array([0, 2, 1])) - False - """ - if coord.shape[axis] < 3: - return True - else: - n = coord.shape[axis] - delta_pos = (coord.take(np.arange(1, n), axis=axis) >= - coord.take(np.arange(0, n - 1), axis=axis)) - delta_neg = (coord.take(np.arange(1, n), axis=axis) <= - coord.take(np.arange(0, n - 1), axis=axis)) - return np.all(delta_pos) or np.all(delta_neg) - - -def _infer_interval_breaks(coord, axis=0, check_monotonic=False): - """ - >>> _infer_interval_breaks(np.arange(5)) - array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) - >>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1) - array([[-0.5, 0.5, 1.5], - [ 2.5, 3.5, 4.5]]) - """ - coord = np.asarray(coord) - - if check_monotonic and not _is_monotonic(coord, axis=axis): - raise ValueError("The input coordinate is not sorted in increasing " - "order along axis %d. This can lead to unexpected " - "results. Consider calling the `sortby` method on " - "the input DataArray. To plot data with categorical " - "axes, consider using the `heatmap` function from " - "the `seaborn` statistical plotting library." % axis) - - deltas = 0.5 * np.diff(coord, axis=axis) - if deltas.size == 0: - deltas = np.array(0.0) - first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis) - last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis) - trim_last = tuple(slice(None, -1) if n == axis else slice(None) - for n in range(coord.ndim)) - return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis) - - @_plot2d def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): """ diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a42fbc7aba6..6d812fbc2bc 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,10 +1,13 @@ import itertools import textwrap import warnings +from datetime import datetime import numpy as np import pandas as pd +from inspect import getfullargspec + from ..core.options import OPTIONS from ..core.utils import is_scalar @@ -447,3 +450,232 @@ def _valid_other_type(x, types): Do all elements of x have a type from types? """ return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) + + +def _valid_numpy_subdtype(x, numpy_types): + """ + Is any dtype from numpy_types superior to the dtype of x? + """ + # If any of the types given in numpy_types is understood as numpy.generic, + # all possible x will be considered valid. This is probably unwanted. + for t in numpy_types: + assert not np.issubdtype(np.generic, t) + + return any(np.issubdtype(x.dtype, t) for t in numpy_types) + + +def _ensure_plottable(*args): + """ + Raise exception if there is anything in args that can't be plotted on an + axis by matplotlib. + """ + numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] + other_types = [datetime] + + for x in args: + if not (_valid_numpy_subdtype(np.array(x), numpy_types) + or _valid_other_type(np.array(x), other_types)): + raise TypeError('Plotting requires coordinates to be numeric ' + 'or dates of type np.datetime64 or ' + 'datetime.datetime or pd.Interval.') + + +def _ensure_numeric(arr): + numpy_types = [np.floating, np.integer] + return _valid_numpy_subdtype(arr, numpy_types) + + +def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): + plt = import_matplotlib_pyplot() + cbar_kwargs.setdefault('extend', cmap_params['extend']) + if cbar_ax is None: + cbar_kwargs.setdefault('ax', ax) + else: + cbar_kwargs.setdefault('cax', cbar_ax) + + cbar = plt.colorbar(primitive, **cbar_kwargs) + + return cbar + + +def _rescale_imshow_rgb(darray, vmin, vmax, robust): + assert robust or vmin is not None or vmax is not None + # TODO: remove when min numpy version is bumped to 1.13 + # There's a cyclic dependency via DataArray, so we can't import from + # xarray.ufuncs in global scope. + from xarray.ufuncs import maximum, minimum + + # Calculate vmin and vmax automatically for `robust=True` + if robust: + if vmax is None: + vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE) + if vmin is None: + vmin = np.nanpercentile(darray, ROBUST_PERCENTILE) + # If not robust and one bound is None, calculate the default other bound + # and check that an interval between them exists. + elif vmax is None: + vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 + if vmax < vmin: + raise ValueError( + 'vmin=%r is less than the default vmax (%r) - you must supply ' + 'a vmax > vmin in this case.' % (vmin, vmax)) + elif vmin is None: + vmin = 0 + if vmin > vmax: + raise ValueError( + 'vmax=%r is less than the default vmin (0) - you must supply ' + 'a vmin < vmax in this case.' % vmax) + # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float + # to avoid precision loss, integer over/underflow, etc with extreme inputs. + # After scaling, downcast to 32-bit float. This substantially reduces + # memory usage after we hand `darray` off to matplotlib. + darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'xarray.ufuncs', + PendingDeprecationWarning) + return minimum(maximum(darray, 0), 1) + + +def _update_axes(ax, xincrease, yincrease, + xscale=None, yscale=None, + xticks=None, yticks=None, + xlim=None, ylim=None): + """ + Update axes with provided parameters + """ + if xincrease is None: + pass + elif xincrease and ax.xaxis_inverted(): + ax.invert_xaxis() + elif not xincrease and not ax.xaxis_inverted(): + ax.invert_xaxis() + + if yincrease is None: + pass + elif yincrease and ax.yaxis_inverted(): + ax.invert_yaxis() + elif not yincrease and not ax.yaxis_inverted(): + ax.invert_yaxis() + + # The default xscale, yscale needs to be None. + # If we set a scale it resets the axes formatters, + # This means that set_xscale('linear') on a datetime axis + # will remove the date labels. So only set the scale when explicitly + # asked to. https://github.com/matplotlib/matplotlib/issues/8740 + if xscale is not None: + ax.set_xscale(xscale) + if yscale is not None: + ax.set_yscale(yscale) + + if xticks is not None: + ax.set_xticks(xticks) + if yticks is not None: + ax.set_yticks(yticks) + + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + + +def _is_monotonic(coord, axis=0): + """ + >>> _is_monotonic(np.array([0, 1, 2])) + True + >>> _is_monotonic(np.array([2, 1, 0])) + True + >>> _is_monotonic(np.array([0, 2, 1])) + False + """ + if coord.shape[axis] < 3: + return True + else: + n = coord.shape[axis] + delta_pos = (coord.take(np.arange(1, n), axis=axis) >= + coord.take(np.arange(0, n - 1), axis=axis)) + delta_neg = (coord.take(np.arange(1, n), axis=axis) <= + coord.take(np.arange(0, n - 1), axis=axis)) + return np.all(delta_pos) or np.all(delta_neg) + + +def _infer_interval_breaks(coord, axis=0, check_monotonic=False): + """ + >>> _infer_interval_breaks(np.arange(5)) + array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) + >>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1) + array([[-0.5, 0.5, 1.5], + [ 2.5, 3.5, 4.5]]) + """ + coord = np.asarray(coord) + + if check_monotonic and not _is_monotonic(coord, axis=axis): + raise ValueError("The input coordinate is not sorted in increasing " + "order along axis %d. This can lead to unexpected " + "results. Consider calling the `sortby` method on " + "the input DataArray. To plot data with categorical " + "axes, consider using the `heatmap` function from " + "the `seaborn` statistical plotting library." % axis) + + deltas = 0.5 * np.diff(coord, axis=axis) + if deltas.size == 0: + deltas = np.array(0.0) + first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis) + last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis) + trim_last = tuple(slice(None, -1) if n == axis else slice(None) + for n in range(coord.ndim)) + return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis) + + +def _process_cmap_cbar_kwargs(func, kwargs, data): + """ + Parameters + ========== + func : plotting function + kwargs : dict, + Dictionary with arguments that need to be parsed + data : ndarray, + Data values + + Returns + ======= + cmap_params + + cbar_kwargs + """ + + cmap = kwargs.pop('cmap', None) + colors = kwargs.pop('colors', None) + + cbar_kwargs = kwargs.pop('cbar_kwargs', {}) + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) + + levels = kwargs.pop('levels', None) + if 'contour' in func.__name__ and levels is None: + levels = 7 # this is the matplotlib default + + # colors is mutually exclusive with cmap + if cmap and colors: + raise ValueError("Can't specify both cmap and colors.") + + # colors is only valid when levels is supplied or the plot is of type + # contour or contourf + if colors and (('contour' not in func.__name__) and (not levels)): + raise ValueError("Can only specify colors with contour or levels") + + # we should not be getting a list of colors in cmap anymore + # is there a better way to do this test? + if isinstance(cmap, (list, tuple)): + warnings.warn("Specifying a list of colors in cmap is deprecated. " + "Use colors keyword instead.", + DeprecationWarning, stacklevel=3) + + cmap_kwargs = {'plot_data': data, + 'levels': levels, + 'cmap': colors if colors else cmap, + 'filled': func.__name__ != 'contour'} + + cmap_args = getfullargspec(_determine_cmap_params).args + cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs) + cmap_params = _determine_cmap_params(**cmap_kwargs) + + return cmap_params, cbar_kwargs