diff --git a/src/mplhep/__init__.py b/src/mplhep/__init__.py index b1551a55..7768aca4 100644 --- a/src/mplhep/__init__.py +++ b/src/mplhep/__init__.py @@ -18,6 +18,7 @@ hist2dplot, histplot, make_square_add_cbar, + merge_legend_handles_labels, mpl_magic, rescale_to_axessize, sort_legend, @@ -70,6 +71,7 @@ "rescale_to_axessize", "box_aspect", "make_square_add_cbar", + "merge_legend_handles_labels", "append_axes", "sort_legend", "save_variations", diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 8d46fd0d..8ae57be2 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -1408,3 +1408,28 @@ def sort_legend(ax, order=None): if isinstance(order, OrderedDict): ordered_label_list = [order[k] for k in ordered_label_list] return ordered_label_values, ordered_label_list + + +def merge_legend_handles_labels(handles, labels): + """ + Merge handles for identical labels. + This is useful when combining multiple plot functions into a single label. + + handles : List of handles + labels : List of labels + """ + + seen_labels = [] + seen_label_handles = [] + for handle, label in zip(handles, labels): + if label not in seen_labels: + seen_labels.append(label) + seen_label_handles.append([handle]) + else: + idx = seen_labels.index(label) + seen_label_handles[idx].append(handle) + + for i in range(len(seen_labels)): + seen_label_handles[i] = tuple(seen_label_handles[i]) + + return seen_label_handles, seen_labels