Skip to content

Commit

Permalink
Add plotting 11/n (#1621)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Mar 15, 2023
1 parent 898aef5 commit 1147b32
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1605](https://github.com/Lightning-AI/metrics/pull/1605),
[#1610](https://github.com/Lightning-AI/metrics/pull/1610),
[#1609](https://github.com/Lightning-AI/metrics/pull/1609),
[#1621](https://github.com/Lightning-AI/metrics/pull/1621),
)


Expand Down
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
from typing import Any, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor

from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["PearsonCorrCoef.plot"]


def _final_aggregation(
Expand Down Expand Up @@ -159,3 +164,44 @@ def compute(self) -> Tensor:
corr_xy = self.corr_xy
n_total = self.n_total
return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["R2Score.plot"]


class R2Score(Metric):
Expand Down Expand Up @@ -129,3 +134,44 @@ def compute(self) -> Tensor:
return _r2_score_compute(
self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.regression.spearman import _spearman_corrcoef_compute, _spearman_corrcoef_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpearmanCorrCoef.plot"]


class SpearmanCorrCoef(Metric):
Expand Down Expand Up @@ -95,3 +100,44 @@ def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _spearman_corrcoef_compute(preds, target)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/symmetric_mape.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

Expand All @@ -20,6 +20,11 @@
_symmetric_mean_absolute_percentage_error_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SymmetricMeanAbsolutePercentageError.plot"]


class SymmetricMeanAbsolutePercentageError(Metric):
Expand Down Expand Up @@ -74,3 +79,44 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute mean absolute percentage error over state."""
return _symmetric_mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/tweedie_deviance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -21,6 +21,11 @@
_tweedie_deviance_score_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["TweedieDevianceScore.plot"]


class TweedieDevianceScore(Metric):
Expand Down Expand Up @@ -99,3 +104,44 @@ def update(self, preds: Tensor, targets: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute metric."""
return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import TweedieDevianceScore
>>> metric = TweedieDevianceScore()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import TweedieDevianceScore
>>> metric = TweedieDevianceScore()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
Loading

0 comments on commit 1147b32

Please sign in to comment.