Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a7d973a
plot_histogram fails when given number_to_keep parameter #7461
iuliazidaru Jan 5, 2022
05a9767
Fix style
iuliazidaru Jan 5, 2022
877322f
fix review comments
iuliazidaru Jan 6, 2022
ea6321a
remove temporary the release notes
iuliazidaru Jan 6, 2022
44f5daa
add release notes
iuliazidaru Jan 6, 2022
94ae6d2
Reformat release notes
iuliazidaru Jan 6, 2022
8ab7d91
Merge branch 'main' into issue7461
iuliazidaru Jan 7, 2022
5a1abe7
Temp - remove code from release note
iuliazidaru Jan 10, 2022
88dd944
Merge branch 'issue7461' of https://github.com/iuliazidaru/qiskit-ter…
iuliazidaru Jan 10, 2022
b997c9a
Add code example in release note.
iuliazidaru Jan 11, 2022
4dd9589
Add code example in release note.
iuliazidaru Jan 11, 2022
2705650
Add code example in release note.
iuliazidaru Jan 11, 2022
fcf7025
add test for multiple executions display
iuliazidaru Jan 12, 2022
30583fe
change dictionary from OderedDict to defacultdict
iuliazidaru Jan 19, 2022
993d686
Merge branch 'main' into issue7461
iuliazidaru Feb 23, 2022
f636ac5
Merge branch 'main' into issue7461
iuliazidaru Mar 1, 2022
dbbc4f2
small fixes
iuliazidaru Mar 4, 2022
b8fb075
Correct use of optional tests
jakelishman Mar 4, 2022
6337b6c
Reword release note
jakelishman Mar 4, 2022
25545c8
Improve documentation of `number_to_keep`
jakelishman Mar 4, 2022
847c6a6
Fix crash on distance measures
jakelishman Mar 4, 2022
79b7e01
Merge branch 'main' into issue7461
iuliazidaru Mar 4, 2022
e4ade96
Revert if.
iuliazidaru Mar 4, 2022
bfef156
Merge branch 'main' into issue7461
iuliazidaru Mar 11, 2022
235a9d6
refactor code & add test
iuliazidaru Mar 11, 2022
d54af7b
Merge branch 'main' into issue7461
mergify[bot] Jun 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions qiskit/visualization/counts_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Visualization functions for measurement counts.
"""

from collections import Counter, OrderedDict
from collections import OrderedDict
import functools
import numpy as np

Expand Down Expand Up @@ -64,8 +64,11 @@ def plot_histogram(
dict containing the values to represent (ex {'001': 130})
figsize (tuple): Figure size in inches.
color (list or str): String or list of strings for histogram bar colors.
number_to_keep (int): The number of terms to plot and rest
is made into a single bar called 'rest'.
number_to_keep (int): The number of terms to plot per dataset. The rest is made into a
single bar called 'rest'. If multiple datasets are given, the ``number_to_keep``
applies to each dataset individually, which may result in more bars than
``number_to_keep + 1``. The ``number_to_keep`` applies to the total values, rather than
the x-axis sort.
sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or
`'value_desc'`. If set to `'value'` or `'value_desc'` the x axis
will be sorted by the maximum probability for each bitstring.
Expand Down Expand Up @@ -148,7 +151,7 @@ def plot_histogram(
if sort in DIST_MEAS:
dist = []
for item in labels:
dist.append(DIST_MEAS[sort](item, target_string))
dist.append(DIST_MEAS[sort](item, target_string) if item != "rest" else 0)

labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1]
elif "value" in sort:
Expand Down Expand Up @@ -241,6 +244,26 @@ def plot_histogram(
return fig.savefig(filename)


def _keep_largest_items(execution, number_to_keep):
"""Keep only the largest values in a dictionary, and sum the rest into a new key 'rest'."""
sorted_counts = sorted(execution.items(), key=lambda p: p[1])
rest = sum(count for key, count in sorted_counts[:-number_to_keep])
return dict(sorted_counts[-number_to_keep:], rest=rest)


def _unify_labels(data):
"""Make all dictionaries in data have the same set of keys, using 0 for missing values."""
data = tuple(data)
all_labels = set().union(*(execution.keys() for execution in data))
base = {label: 0 for label in all_labels}
out = []
for execution in data:
new_execution = base.copy()
new_execution.update(execution)
out.append(new_execution)
return out


def _plot_histogram_data(data, labels, number_to_keep):
"""Generate the data needed for plotting counts.

Expand All @@ -259,22 +282,21 @@ def _plot_histogram_data(data, labels, number_to_keep):
experiment.
"""
labels_dict = OrderedDict()

all_pvalues = []
all_inds = []

if isinstance(data, dict):
data = [data]
if number_to_keep is not None:
data = _unify_labels(_keep_largest_items(execution, number_to_keep) for execution in data)

for execution in data:
if number_to_keep is not None:
data_temp = dict(Counter(execution).most_common(number_to_keep))
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
execution = data_temp
values = []
for key in labels:
if key not in execution:
if number_to_keep is None:
labels_dict[key] = 1
values.append(0)
else:
values.append(-1)
else:
labels_dict[key] = 1
values.append(execution[key])
Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/fix_plot_histogram_number-a0a4a023dfad3c70.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
fixes:
- |
Fixed a bug in :func:`~qiskit.visualization.plot_histogram` when the
``number_to_keep`` argument was smaller than the number of keys. The
following code will not throw errors and will be properly aligned::

from qiskit.visualization import plot_histogram
data = {'00': 3, '01': 5, '11': 8, '10': 11}
plot_histogram(data, number_to_keep=2)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 13 additions & 0 deletions test/ipynb/mpl/graph/test_graph_matplotlib_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ def test_plot_histogram(self):

self.graph_count_drawer(counts, filename="histogram.png")

def test_plot_histogram_with_rest(self):
"""Test plot_histogram with a single dataset and number_to_keep."""
# One dataset with four keys; number_to_keep=2 is smaller than the number of keys.
data = [{"00": 3, "01": 5, "10": 6, "11": 12}]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_with_rest.png")

def test_plot_histogram_2_sets_with_rest(self):
"""Test plot_histogram with 2 datasets and number_to_keep."""
# Both datasets share the same four keys but differ in the values for "00" and "01".
data = [
{"00": 3, "01": 5, "10": 6, "11": 12},
{"00": 5, "01": 7, "10": 6, "11": 12},
]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_2_sets_with_rest.png")

def test_plot_histogram_color(self):
"""Test histogram with single color"""

Expand Down
111 changes: 109 additions & 2 deletions test/python/visualization/test_plot_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@
"""Tests for plot_histogram."""

import unittest
from io import BytesIO
from collections import Counter

import matplotlib as mpl
from PIL import Image

from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import plot_histogram
from qiskit.utils import optionals
from .visualization import QiskitVisualizationTestCase


class TestPlotHistogram(QiskitTestCase):
class TestPlotHistogram(QiskitVisualizationTestCase):
"""Qiskit plot_histogram tests."""

@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_different_counts_lengths(self):
"""Test plotting two different length dists works"""
exact_dist = {
Expand Down Expand Up @@ -107,6 +113,107 @@ def test_different_counts_lengths(self):
fig = plot_histogram([raw_dist, exact_dist])
self.assertIsInstance(fig, mpl.figure.Figure)

@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep(self):
"""Test plotting a single dataset using number_to_keep."""
# Four keys with number_to_keep=2, i.e. fewer terms kept than there are keys.
dist = {"00": 3, "01": 5, "11": 8, "10": 11}
fig = plot_histogram(dist, number_to_keep=2)
# Only checks that a Figure is returned; the rendered image is not inspected here.
self.assertIsInstance(fig, mpl.figure.Figure)

@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions(self):
"""Test plotting using number_to_keep with multiple executions."""
# The second execution has no "11" key, so the two datasets have different key sets.
dist = [{"00": 3, "01": 5, "11": 8, "10": 11}, {"00": 3, "01": 7, "10": 11}]
fig = plot_histogram(dist, number_to_keep=2)
# Only checks that a Figure is returned; the rendered image is not inspected here.
self.assertIsInstance(fig, mpl.figure.Figure)

@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions_correct_image(self):
"""Test plotting using number_to_keep with multiple executions"""
data_noisy = {
"00000": 0.22,
"00001": 0.003,
"00010": 0.005,
"00011": 0.0,
"00100": 0.004,
"00101": 0.001,
"00110": 0.004,
"00111": 0.001,
"01000": 0.005,
"01001": 0.0,
"01010": 0.002,
"01011": 0.0,
"01100": 0.225,
"01101": 0.001,
"01110": 0.003,
"01111": 0.003,
"10000": 0.012,
"10001": 0.002,
"10010": 0.001,
"10011": 0.001,
"10100": 0.247,
"10101": 0.004,
"10110": 0.003,
"10111": 0.001,
"11000": 0.225,
"11001": 0.005,
"11010": 0.002,
"11011": 0.0,
"11100": 0.015,
"11101": 0.004,
"11110": 0.001,
"11111": 0.0,
}
data_ideal = {
"00000": 0.25,
"00001": 0,
"00010": 0,
"00011": 0,
"00100": 0,
"00101": 0,
"00110": 0,
"00111": 0.0,
"01000": 0.0,
"01001": 0,
"01010": 0.0,
"01011": 0.0,
"01100": 0.25,
"01101": 0,
"01110": 0,
"01111": 0,
"10000": 0,
"10001": 0,
"10010": 0.0,
"10011": 0.0,
"10100": 0.25,
"10101": 0,
"10110": 0,
"10111": 0,
"11000": 0.25,
"11001": 0,
"11010": 0,
"11011": 0,
"11100": 0.0,
"11101": 0,
"11110": 0,
"11111": 0.0,
}
data_ref_noisy = dict(Counter(data_noisy).most_common(5))
data_ref_noisy["rest"] = sum(data_noisy.values()) - sum(data_ref_noisy.values())
data_ref_ideal = dict(Counter(data_ideal).most_common(4)) # do not add 0 values
data_ref_ideal["rest"] = 0
figure_ref = plot_histogram([data_ref_ideal, data_ref_noisy])
figure_truncated = plot_histogram([data_ideal, data_noisy], number_to_keep=5)
with BytesIO() as img_buffer_ref:
figure_ref.savefig(img_buffer_ref, format="png")
img_buffer_ref.seek(0)
Comment on lines +207 to +209
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Someone from the qiskit team will need to check how saving reference images is supposed to work (is it supposed to be saved to a file?). I didn't figure that out on my previous pass at roughly this issue.

Looks good otherwise, took me a second to figure out what this test is actually doing but looks good now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah.. that.

We use snapshot testing for this kind of thing (introduced in #4544). However, it seems that test/python/visualization/test_plot_histogram.py is not using that method, and fixing that is beyond the scope of this PR. Do you think it makes sense to put this on hold until that is fixed?

with BytesIO() as img_buffer:
figure_truncated.savefig(img_buffer, format="png")
img_buffer.seek(0)
self.assertImagesAreEqual(Image.open(img_buffer_ref), Image.open(img_buffer), 0.2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this 0.2 threshold. I don't have a good reference for how big this should be.

mpl.pyplot.close(figure_ref)
mpl.pyplot.close(figure_truncated)


# Allow this test module to be run directly as a script.
if __name__ == "__main__":
unittest.main(verbosity=2)