From b0c316a883799fab3f111b700384ced5eb7dbff6 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 09:41:16 +0100 Subject: [PATCH 1/3] Always add axis label names if argument is specified --- src/torchmetrics/utilities/plot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index feb906810bc..664d718723b 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -307,9 +307,6 @@ def plot_curve( if isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 1 and y.ndim == 1: label = f"AUC={score.item():0.3f}" if score is not None else None ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label) - if label_names is not None: - ax.set_xlabel(label_names[0]) - ax.set_ylabel(label_names[1]) if label is not None: ax.legend() elif (isinstance(x, list) and isinstance(y, list)) or ( @@ -324,6 +321,9 @@ def plot_curve( raise ValueError( f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}." ) + if label_names is not None: + ax.set_xlabel(label_names[0]) + ax.set_ylabel(label_names[1]) ax.grid(True) ax.set_title(name) From 42f3b3b4ad2fff2ffd07fdb689bf53ded5bf281f Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 09:41:29 +0100 Subject: [PATCH 2/3] Consistent plotting style --- src/torchmetrics/utilities/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 664d718723b..b1ec17597b7 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -315,7 +315,7 @@ def plot_curve( for i, (x_, y_) in enumerate(zip(x, y)): label = f"{legend_name}_{i}" if legend_name is not None else str(i) label += f" AUC={score[i].item():0.3f}" if score is not None else "" - ax.plot(x_.detach().cpu(), y_.detach().cpu(), label=label) + ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend() else: raise ValueError( From ccaf36d76b8f930a5f66b12441fa4a38bb0dc99d Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 19 Mar 2024 18:25:31 +0100 Subject: [PATCH 3/3] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e2bd17b325..8a9a7c49d0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) ## [1.3.2] - 2024-03-18