Skip to content

Commit

Permalink
feat: MatPlot updates (Updated version of #337) (#636)
Browse files Browse the repository at this point in the history
* Copied MatPlot changes from Nulinspiratie fork

* fix: Forgot a call to _get_axes

* refactor: made subplot kwarg 1-based

* refactor: made default_figsize static, remove trailing white spaces

* fix: kwargs are also passed if there are multiple args provided.

* fix: also allow multiple subargs as first arg

* fix: forgot to change default subplot=1 in update_plot
  • Loading branch information
nulinspiratie authored and jenshnielsen committed Jun 21, 2017
1 parent 8904025 commit be08c9a
Showing 1 changed file with 174 additions and 42 deletions.
216 changes: 174 additions & 42 deletions qcodes/plots/qcmatplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
using the nbagg backend and matplotlib
"""
from collections import Mapping
from functools import partial

import matplotlib.pyplot as plt
from matplotlib.transforms import Bbox
import numpy as np
from numpy.ma import masked_invalid, getmask
from collections import Sequence

from .base import BasePlot

Expand All @@ -18,10 +20,13 @@ class MatPlot(BasePlot):
in the constructor, other traces can be added with MatPlot.add()
Args:
*args: shortcut to provide the x/y/z data. See BasePlot.add
*args: Sequence of data to plot. Each element will have its own subplot.
An element can be a single array, or a sequence of arrays. In the
latter case, all arrays will be plotted in the same subplot.
figsize (Tuple[Float, Float]): (width, height) tuple in inches to pass to plt.figure
default (8, 5)
figsize (Tuple[Float, Float]): (width, height) tuple in inches to pass
to plt.figure. If not provided, figsize is determined from
subplots shape
interval: period in seconds between update checks
Expand All @@ -35,35 +40,80 @@ class MatPlot(BasePlot):
**kwargs: passed along to MatPlot.add() to add the first data trace
"""

# Maximum default number of subplot columns. Used to determine shape of
# subplots when not explicitly provided
max_subplot_columns = 3

def __init__(self, *args, figsize=None, interval=1, subplots=None, num=None,
**kwargs):

super().__init__(interval)

if subplots is None:
# Subplots is equal to number of args, or 1 if no args provided
subplots = max(len(args), 1)

self._init_plot(subplots, figsize, num=num)
if args or kwargs:
self.add(*args, **kwargs)

# Add data to plot if passed in args, kwargs are passed to all subplots
for k, arg in enumerate(args):
if isinstance(arg, Sequence):
# Arg consists of multiple elements, add all to same subplot
for subarg in arg:
self[k].add(subarg, **kwargs)
else:
# Arg is single element, add to subplot
self[k].add(arg, **kwargs)

def _init_plot(self, subplots=None, figsize=None, num=None):
if figsize is None:
figsize = (8, 5)
self.tight_layout()

if subplots is None:
subplots = (1, 1)
def __getitem__(self, key):
"""
Subplots can be accessed via indices.
Args:
key: subplot idx
Returns:
Subplot with idx key
"""
return self.subplots[key]

def _init_plot(self, subplots=None, figsize=None, num=None):
if isinstance(subplots, Mapping):
if figsize is None:
figsize = (6, 4)
self.fig, self.subplots = plt.subplots(figsize=figsize, num=num,
**subplots, squeeze=False)
else:
# Format subplots as tuple (nrows, ncols)
if isinstance(subplots, int):
# self.max_subplot_columns defines the limit on how many
# subplots can be in one row. Adjust subplot rows and columns
# accordingly
nrows = int(np.ceil(subplots / self.max_subplot_columns))
ncols = min(subplots, self.max_subplot_columns)
subplots = (nrows, ncols)

if figsize is None:
# Adjust figsize depending on rows and columns in subplots
figsize = self.default_figsize(subplots)

self.fig, self.subplots = plt.subplots(*subplots, num=num,
figsize=figsize, squeeze=False)
figsize=figsize,
squeeze=False)

# squeeze=False ensures that subplots is always a 2D array independent of the number
# of subplots.
# squeeze=False ensures that subplots is always a 2D array independent
# of the number of subplots.
# However the qcodes api assumes that subplots is always a 1D array
# so flatten here
self.subplots = self.subplots.flatten()

for k, subplot in enumerate(self.subplots):
# Include `add` method to subplots, making it easier to add data to
# subplots. Note that subplot kwarg is 1-based, to adhere to
# Matplotlib standards
subplot.add = partial(self.add, subplot=k+1)

self.title = self.fig.suptitle('')

def clear(self, subplots=None, figsize=None):
Expand All @@ -75,28 +125,35 @@ def clear(self, subplots=None, figsize=None):
self.fig.clf()
self._init_plot(subplots, figsize, num=self.fig.number)

def add_to_plot(self, **kwargs):
def add_to_plot(self, use_offset=False, **kwargs):
"""
adds one trace to this MatPlot.
kwargs: with the following exceptions (mostly the data!), these are
passed directly to the matplotlib plotting routine.
`subplot`: the 1-based axes number to append to (default 1)
if kwargs include `z`, we will draw a heatmap (ax.pcolormesh):
`x`, `y`, and `z` are passed as positional args to pcolormesh
without `z` we draw a scatter/lines plot (ax.plot):
`x`, `y`, and `fmt` (if present) are passed as positional args
Args:
use_offset (bool, Optional): Whether or not ticks can have an offset
kwargs: with the following exceptions (mostly the data!), these are
passed directly to the matplotlib plotting routine.
`subplot`: the 1-based axes number to append to (default 1)
if kwargs include `z`, we will draw a heatmap (ax.pcolormesh):
`x`, `y`, and `z` are passed as positional args to
pcolormesh
without `z` we draw a scatter/lines plot (ax.plot):
`x`, `y`, and `fmt` (if present) are passed as positional
args
"""
# TODO some way to specify overlaid axes?
ax = self._get_axes(kwargs)
# Note that there is a conversion from subplot kwarg, which is
# 1-based, to subplot idx, which is 0-based.
ax = self[kwargs.get('subplot', 1) - 1]
if 'z' in kwargs:
plot_object = self._draw_pcolormesh(ax, **kwargs)
else:
plot_object = self._draw_plot(ax, **kwargs)

# Specify if axes ticks can have offset or not
ax.ticklabel_format(useOffset=use_offset)

self._update_labels(ax, kwargs)
prev_default_title = self.get_default_title()

Expand All @@ -109,9 +166,6 @@ def add_to_plot(self, **kwargs):
# in case the user has updated title, don't change it anymore
self.title.set_text(self.get_default_title())

def _get_axes(self, config):
return self.subplots[config.get('subplot', 1) - 1]

def _update_labels(self, ax, config):
for axletter in ("x", "y"):
if axletter+'label' in config:
Expand Down Expand Up @@ -146,6 +200,21 @@ def _update_labels(self, ax, config):
axsetter = getattr(ax, "set_{}label".format(axletter))
axsetter("{} ({})".format(label, unit))

@staticmethod
def default_figsize(subplots):
"""
Provides default figsize for given subplots.
Args:
subplots (Tuple[Int, Int]): shape (nrows, ncols) of subplots
Returns:
Figsize (Tuple[Float, Float])): (width, height) of default figsize
for given subplot shape
"""
if not isinstance(subplots, tuple):
raise TypeError('Subplots {} must be a tuple'.format(subplots))
return (min(3 + 3 * subplots[1], 12), 1 + 3 * subplots[0])

def update_plot(self):
"""
update the plot. The DataSets themselves have already been updated
Expand All @@ -164,7 +233,7 @@ def update_plot(self):
if plot_object:
plot_object.remove()

ax = self._get_axes(config)
ax = self[config.get('subplot', 1) - 1]
plot_object = self._draw_pcolormesh(ax, **config)
trace['plot_object'] = plot_object

Expand Down Expand Up @@ -202,11 +271,12 @@ def _draw_plot(self, ax, y, x=None, fmt=None, subplot=1,
yunit=None,
zunit=None,
**kwargs):
# NOTE(alexj)stripping out subplot because which subplot we're in is already
# described by ax, and it's not a kwarg to matplotlib's ax.plot. But I
# didn't want to strip it out of kwargs earlier because it should stay
# part of trace['config'].
# NOTE(alexj)stripping out subplot because which subplot we're in is
# already described by ax, and it's not a kwarg to matplotlib's ax.plot.
# But I didn't want to strip it out of kwargs earlier because it should
# stay part of trace['config'].
args = [arg for arg in [x, y, fmt] if arg is not None]

line, = ax.plot(*args, **kwargs)
return line

Expand All @@ -217,21 +287,71 @@ def _draw_pcolormesh(self, ax, z, x=None, y=None, subplot=1,
xunit=None,
yunit=None,
zunit=None,
nticks=None,
**kwargs):
# NOTE(alexj)stripping out subplot because which subplot we're in is already
# described by ax, and it's not a kwarg to matplotlib's ax.plot. But I
# didn't want to strip it out of kwargs earlier because it should stay
# part of trace['config'].
args = [masked_invalid(arg) for arg in [x, y, z]
if arg is not None]

for arg in args:
if np.all(getmask(arg)):
# if any entire array is masked, don't draw at all
# there's nothing to draw, and anyway it throws a warning
return False
args_masked = [masked_invalid(arg) for arg in [x, y, z]
if arg is not None]

if np.any([np.all(getmask(arg)) for arg in args_masked]):
# if the z array is masked, don't draw at all
# there's nothing to draw, and anyway it throws a warning
# pcolormesh does not accept masked x and y axes, so we do not need
# to check for them.
return False

if x is not None and y is not None:
# If x and y are provided, modify the arrays such that they
# correspond to grid corners instead of grid centers.
# This is to ensure that pcolormesh centers correctly and
# does not ignore edge points.
args = []
for k, arr in enumerate(args_masked[:-1]):
# If a two-dimensional array is provided, only consider the
# first row/column, depending on the axis
if arr.ndim > 1:
arr = arr[0] if k == 0 else arr[:,0]

if np.ma.is_masked(arr[1]):
# Only the first element is not nan, in this case pad with
# a value, and separate their values by 1
arr_pad = np.pad(arr, (1, 0), mode='symmetric')
arr_pad[:2] += [-0.5, 0.5]
else:
# Add padding on both sides equal to endpoints
arr_pad = np.pad(arr, (1, 1), mode='symmetric')
# Add differences to edgepoints (may be nan)
arr_pad[0] += arr_pad[1] - arr_pad[2]
arr_pad[-1] += arr_pad[-2] - arr_pad[-3]

diff = np.ma.diff(arr_pad) / 2
# Insert value at beginning and end of diff to ensure same
# length
diff = np.insert(diff, 0, diff[0])

arr_pad += diff
# Ignore final value
arr_pad = arr_pad[:-1]
args.append(masked_invalid(arr_pad))
args.append(args_masked[-1])
else:
# Only the masked value of z is used as a mask
args = args_masked[-1:]

pc = ax.pcolormesh(*args, **kwargs)

# Set x and y limits if arrays are provided
if x is not None and y is not None:
ax.set_xlim(np.nanmin(args[0]), np.nanmax(args[0]))
ax.set_ylim(np.nanmin(args[1]), np.nanmax(args[1]))

# Specify preferred number of ticks with labels
if nticks and ax.get_xscale() != 'log' and ax.get_yscale != 'log':
ax.locator_params(nbins=nticks)

if getattr(ax, 'qcodes_colorbar', None):
# update_normal doesn't seem to work...
ax.qcodes_colorbar.update_bruteforce(pc)
Expand All @@ -255,6 +375,11 @@ def _draw_pcolormesh(self, ax, z, x=None, y=None, subplot=1,
label = "{} ({})".format(zlabel, zunit)
ax.qcodes_colorbar.set_label(label)

# Scale colors if z has elements
cmin = np.nanmin(args_masked[-1])
cmax = np.nanmax(args_masked[-1])
ax.qcodes_colorbar.set_clim(cmin, cmax)

return pc

def save(self, filename=None):
Expand All @@ -269,3 +394,10 @@ def save(self, filename=None):
default = "{}.png".format(self.get_default_title())
filename = filename or default
self.fig.savefig(filename)

def tight_layout(self):
"""
Perform a tight layout on the figure. A bit of additional spacing at
the top is also added for the title.
"""
self.fig.tight_layout(rect=[0, 0, 1, 0.95])

0 comments on commit be08c9a

Please sign in to comment.