diff --git a/qiskit/visualization/counts_visualization.py b/qiskit/visualization/counts_visualization.py index cde52c275273..c74e6df96397 100644 --- a/qiskit/visualization/counts_visualization.py +++ b/qiskit/visualization/counts_visualization.py @@ -247,8 +247,8 @@ def _plot_histogram_data(data, labels, number_to_keep): """Generate the data needed for plotting counts. Parameters: - data (list or dict): This is either a list of dictionaries or a single - dict containing the values to represent (ex {'001': 130}) + data (list of dict): This is a list of dictionaries (or a list of a single + dict) containing the values to represent (ex {'001': 130}) labels (list): The list of bitstring labels for the plot. number_to_keep (int): The number of terms to plot and rest is made into a single bar called 'rest'. @@ -264,10 +264,27 @@ def _plot_histogram_data(data, labels, number_to_keep): all_pvalues = [] all_inds = [] + # pre-calculate the items that will be kept, using the normalized dictionaries + # equally weights each dictionary + if number_to_keep is not None: + item_counter = Counter() + for execution in data: + values_total = sum(execution.values()) + normalized_execution = {k: v / values_total for k, v in execution.items()} + item_counter.update(normalized_execution) + keys_to_keep = list(map(lambda elem: elem[0], item_counter.most_common(number_to_keep))) + [ + "rest" + ] + labels = keys_to_keep + for execution in data: if number_to_keep is not None: - data_temp = dict(Counter(execution).most_common(number_to_keep)) - data_temp["rest"] = sum(execution.values()) - sum(data_temp.values()) + data_temp = {k: execution.get(k, 0) for k in labels} + # data_temp = filter(lambda elem: elem[0] in keys_to_keep, execution.items())) + if "rest" in execution.keys(): + data_temp["rest"] = execution["rest"] + else: + data_temp["rest"] = sum(execution.values()) - sum(data_temp.values()) execution = data_temp values = [] for key in labels: @@ -281,6 +298,8 @@ def _plot_histogram_data(data, labels, number_to_keep): labels_dict[key] = 1 values.append(execution[key]) values = np.array(values, dtype=float) + if number_to_keep is not None: + assert len(values) <= number_to_keep + 1 pvalues = values / sum(values) all_pvalues.append(pvalues) numelem = len(values) diff --git a/test/python/visualization/test_plot_histogram.py b/test/python/visualization/test_plot_histogram.py index aca1e02061a2..92aad7b719ed 100644 --- a/test/python/visualization/test_plot_histogram.py +++ b/test/python/visualization/test_plot_histogram.py @@ -13,15 +13,25 @@ """Tests for plot_histogram.""" import unittest -import matplotlib as mpl -from qiskit.test import QiskitTestCase -from qiskit.tools.visualization import plot_histogram +from io import BytesIO +from PIL import Image +from qiskit.tools.visualization import plot_histogram, HAS_MATPLOTLIB -class TestPlotHistogram(QiskitTestCase): +# from .visualization import path_to_diagram_reference +from .visualization import QiskitVisualizationTestCase + + +if HAS_MATPLOTLIB: + import matplotlib as mpl + import matplotlib.pyplot as plt + + +class TestPlotHistogram(QiskitVisualizationTestCase): """Qiskit plot_histogram tests.""" + @unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.") def test_different_counts_lengths(self): """Test plotting two different length dists works""" exact_dist = { @@ -107,6 +117,99 @@ def test_different_counts_lengths(self): fig = plot_histogram([raw_dist, exact_dist]) self.assertIsInstance(fig, mpl.figure.Figure) + @unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.") + def test_plot_histogram_bars_align(self): + """Test issue #6692""" + data_noisy = { + "00000": 0.22, + "00001": 0.003, + "00010": 0.005, + "00100": 0.004, + "00101": 0.001, + "00110": 0.004, + "00111": 0.001, + "01000": 0.005, + "01010": 0.002, + "01100": 0.225, + "01101": 0.001, + "01110": 0.003, + "01111": 0.003, + "10000": 0.012, + "10001": 0.002, + "10010": 0.001, + "10011": 0.001, + "10100": 0.247, + "10101": 0.004, + "10110": 0.003, + "10111": 0.001, + "11000": 0.225, + "11001": 0.005, + "11010": 0.002, + "11100": 0.015, + "11101": 0.004, + "11110": 0.001, + } + data_ideal = { + "00000": 0.25, + "01100": 0.25, + "10100": 0.25, + "11000": 0.25, + } + data_ideal_reduced = { + "00000": 0.25, + "01100": 0.25, + "10100": 0.25, + "11000": 0.25, + "rest": 0, + } + data_noisy_reduced = { + "00000": 0.22, + "01100": 0.225, + "10100": 0.247, + "11000": 0.225, + "11100": 0.015, + "rest": 0.083, + } + fig_reduced = plot_histogram([data_ideal, data_noisy], number_to_keep=5) + fig_manual_reduced = plot_histogram([data_ideal_reduced, data_noisy_reduced]) + self.assertIsInstance(fig_reduced, mpl.figure.Figure) + + # Check images nearly match (ordering of bars is different) + # img_ref = path_to_diagram_reference("plot_histogram_reduced_states.png") + # img_manual_ref = path_to_diagram_reference("plot_histogram_reduced_states_manual.png") + with BytesIO() as img_buffer: + fig_reduced.savefig(img_buffer, format="png") + img_buffer.seek(0) + with BytesIO() as img_manual_buffer: + fig_manual_reduced.savefig(img_manual_buffer, format="png") + img_manual_buffer.seek(0) + self.assertImagesAreEqual( + Image.open(img_buffer), Image.open(img_manual_buffer), 0.01 + ) + # self.assertImagesAreEqual(Image.open(img_manual_buffer), img_manual_ref, 0.01) + # self.assertImagesAreEqual(Image.open(img_buffer), img_ref, 0.01) + plt.close(fig_reduced) + plt.close(fig_manual_reduced) + + @unittest.skipIf(not HAS_MATPLOTLIB, "matplotlib not available.") + def test_plot_histogram_number_to_keep(self): + """Test that histograms using number_to_keep produce outputs.""" + data_ideal = { + "000": 0.25, + "110": 0.25, + "011": 0.25, + "101": 0.25, + } + data_noisy = { + "000": 0.24, + "110": 0.25, + "011": 0.24, + "101": 0.24, + "001": 0.03, + } + fig_few_items = plot_histogram([data_ideal, data_noisy], number_to_keep=4) + self.assertIsInstance(fig_few_items, mpl.figure.Figure) + if __name__ == "__main__": unittest.main(verbosity=2)