diff --git a/qiskit/visualization/counts_visualization.py b/qiskit/visualization/counts_visualization.py index c8682a515b8f..98bba2222839 100644 --- a/qiskit/visualization/counts_visualization.py +++ b/qiskit/visualization/counts_visualization.py @@ -14,7 +14,7 @@ Visualization functions for measurement counts. """ -from collections import Counter, OrderedDict +from collections import OrderedDict import functools import numpy as np @@ -64,8 +64,11 @@ def plot_histogram( dict containing the values to represent (ex {'001': 130}) figsize (tuple): Figure size in inches. color (list or str): String or list of strings for histogram bar colors. - number_to_keep (int): The number of terms to plot and rest - is made into a single bar called 'rest'. + number_to_keep (int): The number of terms to plot per dataset. The rest is made into a + single bar called 'rest'. If multiple datasets are given, the ``number_to_keep`` + applies to each dataset individually, which may result in more bars than + ``number_to_keep + 1``. The ``number_to_keep`` applies to the total values, rather than + the x-axis sort. sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or `'value_desc'`. If set to `'value'` or `'value_desc'` the x axis will be sorted by the maximum probability for each bitstring. @@ -148,7 +151,7 @@ def plot_histogram( if sort in DIST_MEAS: dist = [] for item in labels: - dist.append(DIST_MEAS[sort](item, target_string)) + dist.append(DIST_MEAS[sort](item, target_string) if item != "rest" else 0) labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1] elif "value" in sort: @@ -241,6 +244,26 @@ def plot_histogram( return fig.savefig(filename) +def _keep_largest_items(execution, number_to_keep): + """Keep only the largest values in a dictionary, and sum the rest into a new key 'rest'.""" + sorted_counts = sorted(execution.items(), key=lambda p: p[1]) + rest = sum(count for key, count in sorted_counts[:-number_to_keep]) + return dict(sorted_counts[-number_to_keep:], rest=rest) + + +def _unify_labels(data): + """Make all dictionaries in data have the same set of keys, using 0 for missing values.""" + data = tuple(data) + all_labels = set().union(*(execution.keys() for execution in data)) + base = {label: 0 for label in all_labels} + out = [] + for execution in data: + new_execution = base.copy() + new_execution.update(execution) + out.append(new_execution) + return out + + def _plot_histogram_data(data, labels, number_to_keep): """Generate the data needed for plotting counts. @@ -259,22 +282,21 @@ def _plot_histogram_data(data, labels, number_to_keep): experiment. """ labels_dict = OrderedDict() - all_pvalues = [] all_inds = [] + + if isinstance(data, dict): + data = [data] + if number_to_keep is not None: + data = _unify_labels(_keep_largest_items(execution, number_to_keep) for execution in data) + for execution in data: - if number_to_keep is not None: - data_temp = dict(Counter(execution).most_common(number_to_keep)) - data_temp["rest"] = sum(execution.values()) - sum(data_temp.values()) - execution = data_temp values = [] for key in labels: if key not in execution: if number_to_keep is None: labels_dict[key] = 1 values.append(0) - else: - values.append(-1) else: labels_dict[key] = 1 values.append(execution[key]) diff --git a/releasenotes/notes/fix_plot_histogram_number-a0a4a023dfad3c70.yaml b/releasenotes/notes/fix_plot_histogram_number-a0a4a023dfad3c70.yaml new file mode 100644 index 000000000000..76372004ce86 --- /dev/null +++ b/releasenotes/notes/fix_plot_histogram_number-a0a4a023dfad3c70.yaml @@ -0,0 +1,10 @@ +--- +fixes: + - | + Fixed a bug in :func:`~qiskit.visualization.plot_histogram` when the + ``number_to_keep`` argument was smaller that the number of keys. The + following code will not throw errors and will be properly aligned:: + + from qiskit.visualization import plot_histogram + data = {'00': 3, '01': 5, '11': 8, '10': 11} + plot_histogram(data, number_to_keep=2) diff --git a/test/ipynb/mpl/graph/references/histogram_2_sets_with_rest.png b/test/ipynb/mpl/graph/references/histogram_2_sets_with_rest.png new file mode 100644 index 000000000000..d72e46de5429 Binary files /dev/null and b/test/ipynb/mpl/graph/references/histogram_2_sets_with_rest.png differ diff --git a/test/ipynb/mpl/graph/references/histogram_with_rest.png b/test/ipynb/mpl/graph/references/histogram_with_rest.png new file mode 100644 index 000000000000..d91d5f107242 Binary files /dev/null and b/test/ipynb/mpl/graph/references/histogram_with_rest.png differ diff --git a/test/ipynb/mpl/graph/test_graph_matplotlib_drawer.py b/test/ipynb/mpl/graph/test_graph_matplotlib_drawer.py index 2552562dd2cb..879e2c7b21d5 100644 --- a/test/ipynb/mpl/graph/test_graph_matplotlib_drawer.py +++ b/test/ipynb/mpl/graph/test_graph_matplotlib_drawer.py @@ -171,6 +171,19 @@ def test_plot_histogram(self): self.graph_count_drawer(counts, filename="histogram.png") + def test_plot_histogram_with_rest(self): + """test plot_histogram with 2 datasets and number_to_keep""" + data = [{"00": 3, "01": 5, "10": 6, "11": 12}] + self.graph_count_drawer(data, number_to_keep=2, filename="histogram_with_rest.png") + + def test_plot_histogram_2_sets_with_rest(self): + """test plot_histogram with 2 datasets and number_to_keep""" + data = [ + {"00": 3, "01": 5, "10": 6, "11": 12}, + {"00": 5, "01": 7, "10": 6, "11": 12}, + ] + self.graph_count_drawer(data, number_to_keep=2, filename="histogram_2_sets_with_rest.png") + def test_plot_histogram_color(self): """Test histogram with single color""" diff --git a/test/python/visualization/test_plot_histogram.py b/test/python/visualization/test_plot_histogram.py index aca1e02061a2..530f6de613be 100644 --- a/test/python/visualization/test_plot_histogram.py +++ b/test/python/visualization/test_plot_histogram.py @@ -13,15 +13,21 @@ """Tests for plot_histogram.""" import unittest +from io import BytesIO +from collections import Counter + import matplotlib as mpl +from PIL import Image -from qiskit.test import QiskitTestCase from qiskit.tools.visualization import plot_histogram +from qiskit.utils import optionals +from .visualization import QiskitVisualizationTestCase -class TestPlotHistogram(QiskitTestCase): +class TestPlotHistogram(QiskitVisualizationTestCase): """Qiskit plot_histogram tests.""" + @unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.") def test_different_counts_lengths(self): """Test plotting two different length dists works""" exact_dist = { @@ -107,6 +113,107 @@ def test_different_counts_lengths(self): fig = plot_histogram([raw_dist, exact_dist]) self.assertIsInstance(fig, mpl.figure.Figure) + @unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.") + def test_with_number_to_keep(self): + """Test plotting using number_to_keep""" + dist = {"00": 3, "01": 5, "11": 8, "10": 11} + fig = plot_histogram(dist, number_to_keep=2) + self.assertIsInstance(fig, mpl.figure.Figure) + + @unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.") + def test_with_number_to_keep_multiple_executions(self): + """Test plotting using number_to_keep with multiple executions""" + dist = [{"00": 3, "01": 5, "11": 8, "10": 11}, {"00": 3, "01": 7, "10": 11}] + fig = plot_histogram(dist, number_to_keep=2) + self.assertIsInstance(fig, mpl.figure.Figure) + + @unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.") + def test_with_number_to_keep_multiple_executions_correct_image(self): + """Test plotting using number_to_keep with multiple executions""" + data_noisy = { + "00000": 0.22, + "00001": 0.003, + "00010": 0.005, + "00011": 0.0, + "00100": 0.004, + "00101": 0.001, + "00110": 0.004, + "00111": 0.001, + "01000": 0.005, + "01001": 0.0, + "01010": 0.002, + "01011": 0.0, + "01100": 0.225, + "01101": 0.001, + "01110": 0.003, + "01111": 0.003, + "10000": 0.012, + "10001": 0.002, + "10010": 0.001, + "10011": 0.001, + "10100": 0.247, + "10101": 0.004, + "10110": 0.003, + "10111": 0.001, + "11000": 0.225, + "11001": 0.005, + "11010": 0.002, + "11011": 0.0, + "11100": 0.015, + "11101": 0.004, + "11110": 0.001, + "11111": 0.0, + } + data_ideal = { + "00000": 0.25, + "00001": 0, + "00010": 0, + "00011": 0, + "00100": 0, + "00101": 0, + "00110": 0, + "00111": 0.0, + "01000": 0.0, + "01001": 0, + "01010": 0.0, + "01011": 0.0, + "01100": 0.25, + "01101": 0, + "01110": 0, + "01111": 0, + "10000": 0, + "10001": 0, + "10010": 0.0, + "10011": 0.0, + "10100": 0.25, + "10101": 0, + "10110": 0, + "10111": 0, + "11000": 0.25, + "11001": 0, + "11010": 0, + "11011": 0, + "11100": 0.0, + "11101": 0, + "11110": 0, + "11111": 0.0, + } + data_ref_noisy = dict(Counter(data_noisy).most_common(5)) + data_ref_noisy["rest"] = sum(data_noisy.values()) - sum(data_ref_noisy.values()) + data_ref_ideal = dict(Counter(data_ideal).most_common(4)) # do not add 0 values + data_ref_ideal["rest"] = 0 + figure_ref = plot_histogram([data_ref_ideal, data_ref_noisy]) + figure_truncated = plot_histogram([data_ideal, data_noisy], number_to_keep=5) + with BytesIO() as img_buffer_ref: + figure_ref.savefig(img_buffer_ref, format="png") + img_buffer_ref.seek(0) + with BytesIO() as img_buffer: + figure_truncated.savefig(img_buffer, format="png") + img_buffer.seek(0) + self.assertImagesAreEqual(Image.open(img_buffer_ref), Image.open(img_buffer), 0.2) + mpl.pyplot.close(figure_ref) + mpl.pyplot.close(figure_truncated) + if __name__ == "__main__": unittest.main(verbosity=2)