diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 42e20f98..7a11ebb3 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -3,6 +3,7 @@ import collections.abc import inspect import warnings +import logging from collections import OrderedDict, namedtuple from typing import TYPE_CHECKING, Any, Union @@ -962,6 +963,10 @@ def _draw_leg_bbox(ax): """ fig = ax.figure leg = ax.get_legend() + if leg is None: + leg = [ + c for c in ax.get_children() if isinstance(c, plt.matplotlib.legend.Legend) + ][0] fig.canvas.draw() return leg.get_frame().get_bbox() @@ -973,6 +978,7 @@ def _draw_text_bbox(ax): """ fig = ax.figure textboxes = [k for k in ax.get_children() if isinstance(k, AnchoredText)] + fig.canvas.draw() if len(textboxes) > 1: print("Warning: More than one textbox found") for box in textboxes: @@ -984,15 +990,17 @@ def _draw_text_bbox(ax): return bbox -def yscale_legend(ax=None): +def yscale_legend(ax=None, otol=0): """ - Automatically scale y-axis up to fit in legend() + Automatically scale y-axis up to fit in legend(). + Set `otol > 0` for less strict scaling. """ if ax is None: ax = plt.gca() scale_factor = 10 ** (1.05) if ax.get_yscale() == "log" else 1.05 - while overlap(ax, _draw_leg_bbox(ax)) > 0: + while overlap(ax, _draw_leg_bbox(ax)) > otol: + logging.info("Scaling y-axis by 5% to fit legend") ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[-1] * scale_factor) ax.figure.canvas.draw() return ax