diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 664d718723b..b1ec17597b7 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -315,7 +315,7 @@ def plot_curve( for i, (x_, y_) in enumerate(zip(x, y)): label = f"{legend_name}_{i}" if legend_name is not None else str(i) label += f" AUC={score[i].item():0.3f}" if score is not None else "" - ax.plot(x_.detach().cpu(), y_.detach().cpu(), label=label) + ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend() else: raise ValueError(