Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions qiskit/visualization/counts_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def _plot_histogram_data(data, labels, number_to_keep):
"""Generate the data needed for plotting counts.

Parameters:
data (list or dict): This is either a list of dictionaries or a single
dict containing the values to represent (ex {'001': 130})
data (list of dict): This is a list of dictionaries (or a list of a single
dict) containing the values to represent (ex {'001': 130})
labels (list): The list of bitstring labels for the plot.
number_to_keep (int): The number of terms to plot and rest
is made into a single bar called 'rest'.
Expand All @@ -264,10 +264,27 @@ def _plot_histogram_data(data, labels, number_to_keep):

all_pvalues = []
all_inds = []
# pre-calculate the items that will be kept, using the normalized dictionaries
# equally weights each dictionary
if number_to_keep is not None:
item_counter = Counter()
for execution in data:
values_total = sum(execution.values())
normalized_execution = {k: v / values_total for k, v in execution.items()}
item_counter.update(normalized_execution)
keys_to_keep = list(map(lambda elem: elem[0], item_counter.most_common(number_to_keep))) + [
"rest"
]
labels = keys_to_keep

for execution in data:
if number_to_keep is not None:
data_temp = dict(Counter(execution).most_common(number_to_keep))
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
data_temp = {k: execution.get(k, 0) for k in labels}
# data_temp = filter(lambda elem: elem[0] in keys_to_keep, execution.items()))
if "rest" in execution.keys():
data_temp["rest"] = execution["rest"]
else:
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
execution = data_temp
values = []
for key in labels:
Expand All @@ -281,6 +298,8 @@ def _plot_histogram_data(data, labels, number_to_keep):
labels_dict[key] = 1
values.append(execution[key])
values = np.array(values, dtype=float)
if number_to_keep is not None:
assert len(values) <= number_to_keep + 1
pvalues = values / sum(values)
all_pvalues.append(pvalues)
numelem = len(values)
Expand Down
111 changes: 107 additions & 4 deletions test/python/visualization/test_plot_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,25 @@
"""Tests for plot_histogram."""

import unittest
import matplotlib as mpl

from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import plot_histogram
from io import BytesIO
from PIL import Image

from qiskit.tools.visualization import plot_histogram, HAS_MATPLOTLIB

class TestPlotHistogram(QiskitTestCase):
# from .visualization import path_to_diagram_reference
from .visualization import QiskitVisualizationTestCase


if HAS_MATPLOTLIB:
import matplotlib as mpl
import matplotlib.pyplot as plt


class TestPlotHistogram(QiskitVisualizationTestCase):
"""Qiskit plot_histogram tests."""

@unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.")
def test_different_counts_lengths(self):
"""Test plotting two different length dists works"""
exact_dist = {
Expand Down Expand Up @@ -107,6 +117,99 @@ def test_different_counts_lengths(self):
fig = plot_histogram([raw_dist, exact_dist])
self.assertIsInstance(fig, mpl.figure.Figure)

@unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.")
def test_plot_histogram_bars_align(self):
"""Test issue #6692"""
data_noisy = {
"00000": 0.22,
"00001": 0.003,
"00010": 0.005,
"00100": 0.004,
"00101": 0.001,
"00110": 0.004,
"00111": 0.001,
"01000": 0.005,
"01010": 0.002,
"01100": 0.225,
"01101": 0.001,
"01110": 0.003,
"01111": 0.003,
"10000": 0.012,
"10001": 0.002,
"10010": 0.001,
"10011": 0.001,
"10100": 0.247,
"10101": 0.004,
"10110": 0.003,
"10111": 0.001,
"11000": 0.225,
"11001": 0.005,
"11010": 0.002,
"11100": 0.015,
"11101": 0.004,
"11110": 0.001,
}
data_ideal = {
"00000": 0.25,
"01100": 0.25,
"10100": 0.25,
"11000": 0.25,
}
data_ideal_reduced = {
"00000": 0.25,
"01100": 0.25,
"10100": 0.25,
"11000": 0.25,
"rest": 0,
}
data_noisy_reduced = {
"00000": 0.22,
"01100": 0.225,
"10100": 0.247,
"11000": 0.225,
"11100": 0.015,
"rest": 0.083,
}
fig_reduced = plot_histogram([data_ideal, data_noisy], number_to_keep=5)
fig_manual_reduced = plot_histogram([data_ideal_reduced, data_noisy_reduced])
self.assertIsInstance(fig_reduced, mpl.figure.Figure)

# Check images nearly match (ordering of bars is different)
# img_ref = path_to_diagram_reference("plot_histogram_reduced_states.png")
# img_manual_ref = path_to_diagram_reference("plot_histogram_reduced_states_manual.png")
with BytesIO() as img_buffer:
fig_reduced.savefig(img_buffer, format="png")
img_buffer.seek(0)
with BytesIO() as img_manual_buffer:
fig_manual_reduced.savefig(img_manual_buffer, format="png")
img_manual_buffer.seek(0)
self.assertImagesAreEqual(
Image.open(img_buffer), Image.open(img_manual_buffer), 0.01
)
# self.assertImagesAreEqual(Image.open(img_manual_buffer), img_manual_ref, 0.01)
# self.assertImagesAreEqual(Image.open(img_buffer), img_ref, 0.01)
plt.close(fig_reduced)
plt.close(fig_manual_reduced)

@unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.")
def test_plot_histogram_number_to_keep(self):
"""Test that histograms using number_to_keep produce outputs."""
data_ideal = {
"000": 0.25,
"110": 0.25,
"011": 0.25,
"101": 0.25,
}
data_noisy = {
"000": 0.24,
"110": 0.25,
"011": 0.24,
"101": 0.24,
"001": 0.03,
}
fig_few_items = plot_histogram([data_ideal, data_noisy], number_to_keep=4)
self.assertIsInstance(fig_few_items, mpl.figure.Figure)


if __name__ == "__main__":
unittest.main(verbosity=2)