#!/usr/bin/env python
# NOTE: this span was a whitespace-collapsed git patch. Besides the Python
# units reproduced below, the patch carried the following hunks:
#   * changelog.md, under "### [Latest]":
#       - Add `VarVsVar` plot [!172](https://github.com/umami-hep/puma/pull/172)
#   * puma/__init__.py:
#       from puma.var_vs_var import VarVsVar, VarVsVarPlot
#   * puma/line_plot_2d.py: `get_good_markers` is added to the `puma.utils`
#     import and the default marker in `add` changes from "x" to
#     `get_good_markers()[len(self.plot_objects)]`
#   * puma/tests/expected_plots/test_var_vs_var.png: new binary reference image

"""
Unit test script for the functions in var_vs_var.py
"""

import os
import tempfile
import unittest

import numpy as np
from matplotlib.testing.compare import compare_images

from puma import VarVsVar, VarVsVarPlot
from puma.utils.logging import logger, set_log_level

set_log_level(logger, "DEBUG")


class VarVsVarTestCase(unittest.TestCase):
    """Test class for the puma.var_vs_var functions."""

    def setUp(self):
        self.x_var = np.linspace(100, 250, 20)
        self.y_var_mean = np.exp(-np.linspace(6, 10, 20)) * 10e3
        self.y_var_std = np.sin(self.y_var_mean)

    @staticmethod
    def _make_curve():
        """Return a fresh six-point curve with all optional arguments set."""
        return VarVsVar(
            np.ones(6),
            np.ones(6),
            np.ones(6),
            x_var_widths=np.ones(6),
            key="test",
            fill=True,
            plot_y_std=False,
        )

    def test_var_vs_var_init_wrong_mean_shape(self):
        """ValueError if `x_var` and `y_var_mean` differ in length."""
        with self.assertRaises(ValueError):
            VarVsVar(np.ones(4), np.ones(5), np.ones(5))

    def test_var_vs_var_init_wrong_y_mean_std_shape(self):
        """ValueError if `x_var` and `y_var_std` differ in length."""
        with self.assertRaises(ValueError):
            VarVsVar(np.ones(4), np.ones(4), np.ones(5))

    def test_var_vs_var_init_wrong_x_mean_widths_shape(self):
        """ValueError if `x_var` and `x_var_widths` differ in length."""
        with self.assertRaises(ValueError):
            VarVsVar(np.ones(4), np.ones(4), np.ones(4), x_var_widths=np.ones(5))

    def test_var_vs_var_init(self):
        """Initialisation with consistent inputs must not raise."""
        self._make_curve()

    def test_var_vs_var_eq(self):
        """Two independently built curves with the same content are equal."""
        # Compare two distinct instances; comparing an object with itself
        # would not exercise the field-by-field comparison meaningfully.
        self.assertEqual(self._make_curve(), self._make_curve())

    def test_var_vs_var_eq_different_classes(self):
        """A curve never compares equal to an object of another class."""
        self.assertNotEqual(self._make_curve(), np.ones(6))

    def test_var_vs_var_divide_same(self):
        """Dividing a curve by itself gives a ratio of one everywhere."""
        var_plot = VarVsVar(
            x_var=self.x_var,
            y_var_mean=self.y_var_mean,
            y_var_std=self.y_var_std,
        )
        np.testing.assert_array_almost_equal(var_plot.divide(var_plot)[0], np.ones(20))

    def test_var_vs_var_divide_different_shapes(self):
        """ValueError if the x values of the two curves do not match."""
        var_plot = VarVsVar(
            x_var=self.x_var,
            y_var_mean=self.y_var_mean,
            y_var_std=self.y_var_std,
        )
        var_plot_comp = VarVsVar(
            x_var=np.repeat(self.x_var, 2),
            y_var_mean=np.repeat(self.y_var_mean, 2),
            y_var_std=np.repeat(self.y_var_std, 2),
        )
        with self.assertRaises(ValueError):
            var_plot.divide(var_plot_comp)


class VarVsVarPlotTestCase(
    unittest.TestCase
):  # pylint:disable=too-many-instance-attributes
    """Test class for the puma.var_vs_var_plot"""

    def setUp(self):
        # Set up temp directory for comparison plots and make sure it is
        # removed again after every test, even if the test fails.
        self.tmp_dir = tempfile.TemporaryDirectory()  # pylint:disable=R1732
        self.addCleanup(self.tmp_dir.cleanup)
        self.actual_plots_dir = self.tmp_dir.name
        self.expected_plots_dir = os.path.join(
            os.path.dirname(__file__), "expected_plots"
        )
        np.random.seed(42)
        n_random = 21

        # Two exponential decays on the same x values
        self.x_var = np.linspace(0, 250, num=n_random)
        self.y_var_mean = np.exp(-self.x_var / 200) * 10
        self.y_var_std = np.sin(self.y_var_mean)
        self.x_var_widths = np.ones_like(self.x_var) * 5

        self.y_var_mean_2 = np.exp(-self.x_var / 100) * 10
        self.y_var_std_2 = np.sin(self.y_var_mean_2)

        self.test = VarVsVar(
            x_var=self.x_var,
            y_var_mean=self.y_var_mean,
            y_var_std=self.y_var_std,
            label=r"$10e^{-x/200}$",
            fill=False,
            is_marker=True,
        )
        self.test_2 = VarVsVar(
            x_var=self.x_var,
            y_var_mean=self.y_var_mean_2,
            y_var_std=self.y_var_std_2,
            x_var_widths=self.x_var_widths,
            label=r"$10e^{-x/100}$",
            fill=True,
            plot_y_std=False,
            is_marker=False,
        )

    def test_n_ratio_panels(self):
        """Check if ValueError is raised when we require more than 1 ratio panel"""
        with self.assertRaises(ValueError):
            VarVsVarPlot(
                n_ratio_panels=np.random.randint(2, 10),
            )

    def test_no_reference(self):
        """Check if ValueError is raised when plot ratios without reference"""
        test_plot = VarVsVarPlot(
            n_ratio_panels=1,
        )
        test_plot.add(self.test)
        test_plot.add(self.test_2)
        with self.assertRaises(ValueError):
            test_plot.plot_ratios()

    def test_overwrite_reference(self):
        """Check that the most recently added reference curve wins"""
        test_plot = VarVsVarPlot(
            n_ratio_panels=1,
        )
        test_plot.add(self.test, reference=True)
        test_plot.add(self.test_2, reference=True)
        # Without explicit keys `add` numbers the curves 1, 2, ...; the
        # second curve must have replaced the first one as reference.
        self.assertEqual(test_plot.reference_object, 2)

    def test_same_keys(self):
        """Check if KeyError is raised when we add VarVsVar object with existing key"""
        test_plot = VarVsVarPlot(
            n_ratio_panels=1,
        )
        test_plot.add(self.test, key="1")
        with self.assertRaises(KeyError):
            test_plot.add(self.test_2, key="1")

    def test_output_plot(self):
        """Compare the rendered plot against the stored reference image."""
        test_plot = VarVsVarPlot(
            ylabel=r"$\overline{N}_{trk}$",
            xlabel=r"$p_{T}$ [GeV]",
            grid=True,
            logy=False,
            atlas_second_tag="Unit test plot based on exponential decay.",
            n_ratio_panels=1,
            figsize=(9, 6),
        )
        test_plot.add(self.test, reference=True)
        test_plot.add(self.test_2, reference=False)

        test_plot.draw_hline(4)
        test_plot.draw()

        plotname = "test_var_vs_var.png"
        actual_plot = os.path.join(self.actual_plots_dir, plotname)
        expected_plot = os.path.join(self.expected_plots_dir, plotname)
        test_plot.savefig(actual_plot)
        # Uncomment line below to update expected image
        # test_plot.savefig(expected_plot)

        # `compare_images` returns None when the images agree within `tol`
        self.assertIsNone(compare_images(actual_plot, expected_plot, tol=5))


# ---------------------------------------------------------------------------
# puma/utils/__init__.py -- helper added by the same patch (inserted between
# `get_good_colours` and `get_good_linestyles`)
# ---------------------------------------------------------------------------
def get_good_markers():
    """List of markers adequate for plotting

    A new list is built on every call, so callers may modify the result.

    Returns
    -------
    list
        list with markers
    """
    # TODO needs improvements

    return [
        "o",  # Circle
        "x",  # x
        "v",  # Triangle down
        "^",  # Triangle up
        "D",  # Diamond
        "p",  # Pentagon
        "s",  # Square
    ]
"""Variable vs another variable plot"""
import matplotlib as mpl
import numpy as np
from matplotlib.patches import Rectangle

# TODO: fix the import below
from puma.plot_base import PlotBase, PlotLineObject
from puma.utils import get_good_colours, get_good_markers, logger
from puma.utils.histogram import hist_ratio


class VarVsVar(PlotLineObject):  # pylint: disable=too-many-instance-attributes
    """
    VarVsVar class storing the points of a curve (mean and std of a y variable
    as a function of an x variable) and allowing to calculate the ratio w.r.t.
    another VarVsVar curve.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        x_var: np.ndarray,
        y_var_mean: np.ndarray,
        y_var_std: np.ndarray,
        x_var_widths: np.ndarray = None,
        key: str = None,
        fill: bool = True,
        plot_y_std: bool = True,
        **kwargs,
    ) -> None:
        """Initialise properties of VarVsVar curve object.

        Parameters
        ----------
        x_var : np.ndarray
            Values for x-axis variable, e.g. bin midpoints for binned data
        y_var_mean : np.ndarray
            Mean value for y-axis variable
        y_var_std : np.ndarray
            Std value for y-axis variable
        x_var_widths : np.ndarray, optional
            Widths for x-axis variable, e.g. bin widths for binned data
        key : str, optional
            Identifier for the curve e.g. tagger, by default None
        fill : bool, optional
            Defines do we need to fill box around point, by default True
        plot_y_std : bool, optional
            Defines do we need to plot y_var_std, by default True
        **kwargs : kwargs
            Keyword arguments passed to `PlotLineObject`

        Raises
        ------
        ValueError
            If the lengths of the provided arrays are not identical
        """

        super().__init__(**kwargs)

        # Every per-point array has to match `x_var`; the widths are optional
        # and therefore only checked when given.
        to_check = [("y_var_mean", y_var_mean), ("y_var_std", y_var_std)]
        if x_var_widths is not None:
            to_check.append(("x_var_widths", x_var_widths))
        for name, values in to_check:
            if len(x_var) != len(values):
                raise ValueError(
                    f"Length of `x_var` ({len(x_var)}) and `{name}` "
                    f"({len(values)}) have to be identical."
                )

        self.x_var = np.array(x_var)
        self.x_var_widths = None if x_var_widths is None else np.array(x_var_widths)
        self.y_var_mean = np.array(y_var_mean)
        self.y_var_std = np.array(y_var_std)

        self.key = key
        self.fill = fill
        self.plot_y_std = plot_y_std

    def __eq__(self, other):
        """Compare x values, y mean, y std and key of two VarVsVar objects.

        Objects of any other class always compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        # `np.array_equal` also compares the shapes; an element-wise `==`
        # would broadcast (or fail) for arrays of different length.
        return (
            np.array_equal(self.x_var, other.x_var)
            and np.array_equal(self.y_var_mean, other.y_var_mean)
            and np.array_equal(self.y_var_std, other.y_var_std)
            and self.key == other.key
        )

    def divide(self, other, inverse: bool = False):
        """Calculate ratio between two class objects.

        Parameters
        ----------
        other : VarVsVar class
            Second VarVsVar object to calculate ratio with
        inverse : bool
            If False the ratio is calculated `this / other`,
            if True the inverse is calculated

        Returns
        -------
        np.ndarray
            Ratio
        np.ndarray
            Ratio error

        Raises
        ------
        ValueError
            If the x values are not identical between the 2 objects
        """
        if not np.array_equal(self.x_var, other.x_var):
            raise ValueError("The x variables of the two given objects do not match.")
        nom, nom_err = self.y_var_mean, self.y_var_std
        denom, denom_err = other.y_var_mean, other.y_var_std
        if inverse:
            # swap the roles to get `other / this`
            nom, nom_err, denom, denom_err = denom, denom_err, nom, nom_err

        ratio, ratio_err = hist_ratio(
            numerator=nom,
            denominator=denom,
            numerator_unc=nom_err,
            denominator_unc=denom_err,
            step=False,
        )
        return (ratio, ratio_err)


class VarVsVarPlot(PlotBase):  # pylint: disable=too-many-instance-attributes
    """var_vs_var plot class"""

    def __init__(self, grid: bool = False, **kwargs) -> None:
        """var_vs_var plot properties

        Parameters
        ----------
        grid : bool, optional
            Set the grid for the plots.
        **kwargs : kwargs
            Keyword arguments from `puma.PlotObject`

        Raises
        ------
        ValueError
            If more than 1 ratio panel requested
        """
        super().__init__(grid=grid, **kwargs)

        self.plot_objects = {}
        self.add_order = []
        self.ratios_objects = {}
        self.reference_object = None
        # running x range over all added curves, used by `draw_hline`/`draw`
        self.x_var_min = np.inf
        self.x_var_max = -np.inf
        self.inverse_cut = False
        if self.n_ratio_panels > 1:
            raise ValueError("Not more than one ratio panel supported.")
        self.initialise_figure()

    def add(
        self, curve: VarVsVar, key: str = None, reference: bool = False
    ):  # pylint: disable=too-many-branches
        """Adding VarVsVar object to figure.

        Style attributes of `curve` that are still None are filled with
        defaults; the curve object is modified in place.

        Parameters
        ----------
        curve : VarVsVar class
            VarVsVar curve
        key : str, optional
            Unique identifier for VarVsVar curve, by default None
            (the 1-based position of the curve is used in that case)
        reference : bool, optional
            If VarVsVar is used as reference for ratio calculation, by default False

        Raises
        ------
        KeyError
            If unique identifier key is used twice
        """
        if key is None:
            key = len(self.plot_objects) + 1
        if key in self.plot_objects:
            raise KeyError(f"Duplicated key {key} already used for unique identifier.")

        self.plot_objects[key] = curve
        self.add_order.append(key)
        n_curves = len(self.plot_objects)

        # set linestyle
        if curve.linestyle is None:
            curve.linestyle = "-"
        # set colours, cycling through the palette instead of running out of
        # it when many curves are added
        if curve.colour is None:
            colours = get_good_colours()
            curve.colour = colours[(n_curves - 1) % len(colours)]
        # set alpha
        if curve.alpha is None:
            curve.alpha = 0.8
        # set linewidth
        if curve.linewidth is None:
            curve.linewidth = 1.6

        if curve.is_marker is True:
            if curve.marker is None:
                markers = get_good_markers()
                # The index is deliberately not shifted by -1 (the first curve
                # gets the second marker) so that existing plots keep their
                # look; the modulo only prevents an IndexError once the
                # number of curves reaches the number of markers.
                curve.marker = markers[n_curves % len(markers)]
            # Set markersize
            if curve.markersize is None:
                curve.markersize = 15
            if curve.markeredgewidth is None:
                curve.markeredgewidth = 2

        # set min and max edges
        if curve.x_var_widths is not None:
            left_edge = curve.x_var - curve.x_var_widths / 2
            right_edge = curve.x_var + curve.x_var_widths / 2
        else:
            left_edge = curve.x_var
            right_edge = curve.x_var
        self.x_var_min = min(self.x_var_min, np.sort(left_edge)[0])
        self.x_var_max = max(self.x_var_max, np.sort(right_edge)[-1])

        if reference:
            logger.debug("Setting curve %s as reference.", key)
            self.set_reference(key)

    def set_reference(self, key: str):
        """Setting the reference curve used in the ratios

        A warning is logged if a reference was already set; the new key
        replaces it.

        Parameters
        ----------
        key : str
            Unique identifier of the VarVsVar object
        """
        if self.reference_object is not None:
            logger.warning(
                (
                    "You specified a second curve %s as reference for ratio. "
                    "Using it as new reference instead of %s."
                ),
                key,
                self.reference_object,
            )
        self.reference_object = key

    @staticmethod
    def _draw_curve(axis, elem, y_values, y_errors, **errorbar_kwargs):
        """Draw error bars, optional markers and optional filled boxes of one curve.

        Parameters
        ----------
        axis : matplotlib axis
            Axis to draw on
        elem : VarVsVar
            Curve providing x values, widths and style
        y_values : np.ndarray
            y position of the points
        y_errors : np.ndarray
            y uncertainty of the points
        **errorbar_kwargs : kwargs
            Additional keyword arguments passed to `axis.errorbar`
        """
        error_bar = axis.errorbar(
            elem.x_var,
            y_values,
            xerr=elem.x_var_widths / 2 if elem.x_var_widths is not None else None,
            yerr=y_errors if elem.plot_y_std else np.zeros_like(elem.x_var),
            color=elem.colour,
            fmt="none",
            alpha=elem.alpha,
            linewidth=elem.linewidth,
            **errorbar_kwargs,
        )
        # The last entry of the container holds one line collection per error
        # direction (x and/or y); style all of them, not only the first.
        for bar_lines in error_bar[-1]:
            bar_lines.set_linestyle(elem.linestyle)

        # Draw markers
        if elem.is_marker is True:
            axis.scatter(
                x=elem.x_var,
                y=y_values,
                marker=elem.marker,
                color=elem.colour,
            )

        # Filled box of size (x width) x (2 * y error) centred on each point
        if elem.x_var_widths is not None and elem.fill:
            for x_pos, y_pos, width, height in zip(
                elem.x_var, y_values, elem.x_var_widths, 2 * y_errors
            ):
                axis.add_patch(
                    Rectangle(
                        xy=(
                            x_pos - width / 2,
                            y_pos - height / 2,
                        ),
                        width=width,
                        height=height,
                        linewidth=0,
                        color=elem.colour,
                        alpha=0.3,
                        zorder=1,
                    )
                )

    def plot(self, **kwargs):
        """Plotting curves

        Parameters
        ----------
        **kwargs: kwargs
            Keyword arguments passed to plt.axis.errorbar

        Returns
        -------
        list
            One matplotlib Line2D legend handle per curve, in the order in
            which the curves were added
        """
        logger.debug("Plotting curves")
        plt_handles = []
        for key in self.add_order:
            elem = self.plot_objects[key]
            self._draw_curve(
                self.axis_top,
                elem,
                elem.y_var_mean,
                elem.y_var_std,
                label=elem.label,
                **kwargs,
            )
            # empty proxy line carrying the style of the curve for the legend
            plt_handles.append(
                mpl.lines.Line2D(
                    [],
                    [],
                    color=elem.colour,
                    label=elem.label,
                    linestyle=elem.linestyle,
                    marker=elem.marker,
                )
            )
        return plt_handles

    def plot_ratios(self):
        """Plotting ratio curves.

        Every curve (including the reference itself) is divided by the
        reference curve and drawn in the first ratio panel.

        Raises
        ------
        ValueError
            If no reference curve is defined
        """
        if self.reference_object is None:
            raise ValueError("Please specify a reference curve.")
        reference = self.plot_objects[self.reference_object]
        for key in self.add_order:
            elem = self.plot_objects[key]
            (ratio, ratio_err) = elem.divide(reference)
            self._draw_curve(self.ratio_axes[0], elem, ratio, ratio_err)

    def draw_hline(self, y_val: float):
        """Draw hline in top plot panel.

        The line spans the x range of the curves added so far.

        Parameters
        ----------
        y_val : float
            y value of the horizontal line
        """
        self.axis_top.hlines(
            y=y_val,
            xmin=self.x_var_min,
            xmax=self.x_var_max,
            colors="black",
            linestyle="dotted",
            alpha=0.5,
        )

    def draw(
        self,
        labelpad: int = None,
    ):
        """Draw figure.

        Parameters
        ----------
        labelpad : int, optional
            Spacing in points from the axes bounding box including
            ticks and tick labels, by default None
        """
        # explicitly configured limits win over the range of the curves
        self.set_xlim(
            self.x_var_min if self.xmin is None else self.xmin,
            self.x_var_max if self.xmax is None else self.xmax,
        )
        plt_handles = self.plot()
        if self.n_ratio_panels == 1:
            self.plot_ratios()
        self.set_title()
        self.set_log()
        self.set_y_lim()
        self.set_xlabel()
        self.set_tick_params()
        self.set_ylabel(self.axis_top)

        if self.n_ratio_panels > 0:
            self.set_ylabel(
                self.ratio_axes[0],
                self.ylabel_ratio[0],
                align_right=False,
                labelpad=labelpad,
            )
        self.make_legend(plt_handles, ax_mpl=self.axis_top)
        self.plotting_done = True
        if self.apply_atlas_style is True:
            self.atlasify(use_tag=self.use_atlas_tag)