Add plotting 8/n (#1605)
SkafteNicki authored Mar 10, 2023
1 parent 632e8e6 commit 2d21f6f
Showing 8 changed files with 304 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
    [#1585](https://github.com/Lightning-AI/metrics/pull/1585),
    [#1593](https://github.com/Lightning-AI/metrics/pull/1593),
    [#1600](https://github.com/Lightning-AI/metrics/pull/1600),
    [#1605](https://github.com/Lightning-AI/metrics/pull/1605),
)


48 changes: 48 additions & 0 deletions src/torchmetrics/regression/concordance.py
@@ -11,10 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.regression.concordance import _concordance_corrcoef_compute
from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["ConcordanceCorrCoef.plot"]


class ConcordanceCorrCoef(PearsonCorrCoef):
@@ -74,3 +81,44 @@ def compute(self) -> Tensor:
        corr_xy = self.corr_xy
        n_total = self.n_total
        return _concordance_corrcoef_compute(mean_x, mean_y, var_x, var_y, corr_xy, n_total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import ConcordanceCorrCoef
            >>> metric = ConcordanceCorrCoef()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import ConcordanceCorrCoef
            >>> metric = ConcordanceCorrCoef()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
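
Each regression metric touched by this commit receives the same three-part addition, shown first in concordance.py above: the plotting type aliases are imported, the new doctests are guarded behind a matplotlib availability flag, and a `plot` method forwards to the shared `Metric._plot` helper. As a rough sketch of what that guard relies on (illustrative only, not part of the diff; it assumes `_MATPLOTLIB_AVAILABLE` reduces to an importability check and that the doctest runner honors a module-level `__doctest_skip__` list, as pytest-doctestplus does):

from importlib.util import find_spec

# Hypothetical stand-in for torchmetrics.utilities.imports._MATPLOTLIB_AVAILABLE:
# True only when matplotlib can actually be imported in this environment.
_MATPLOTLIB_AVAILABLE = find_spec("matplotlib") is not None

if not _MATPLOTLIB_AVAILABLE:
    # A doctest runner that understands __doctest_skip__ (e.g. pytest-doctestplus)
    # skips the listed examples, so minimal installs without matplotlib still pass.
    __doctest_skip__ = ["ConcordanceCorrCoef.plot"]
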
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/cosine_similarity.py
@@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.cosine_similarity import _cosine_similarity_compute, _cosine_similarity_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["CosineSimilarity.plot"]


class CosineSimilarity(Metric):
@@ -84,3 +89,44 @@ def compute(self) -> Tensor:
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        return _cosine_similarity_compute(preds, target, self.reduction)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import CosineSimilarity
            >>> metric = CosineSimilarity()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import CosineSimilarity
            >>> metric = CosineSimilarity()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/explained_variance.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Union
from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor
from typing_extensions import Literal
@@ -22,6 +22,11 @@
    _explained_variance_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["ExplainedVariance.plot"]


class ExplainedVariance(Metric):
@@ -121,3 +126,44 @@ def compute(self) -> Union[Tensor, Sequence[Tensor]]:
            self.sum_squared_target,
            self.multioutput,
        )

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import ExplainedVariance
            >>> metric = ExplainedVariance()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import ExplainedVariance
            >>> metric = ExplainedVariance()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/kendall.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal
@@ -25,6 +25,11 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["KendallRankCorrCoef.plot"]


class KendallRankCorrCoef(Metric):
@@ -155,3 +160,44 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if p_value is not None:
            return tau, p_value
        return tau

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import KendallRankCorrCoef
            >>> metric = KendallRankCorrCoef()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import KendallRankCorrCoef
            >>> metric = KendallRankCorrCoef()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/kl_divergence.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
@@ -20,6 +20,11 @@
from torchmetrics.functional.regression.kl_divergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["KLDivergence.plot"]


class KLDivergence(Metric):
@@ -117,3 +122,44 @@ def compute(self) -> Tensor:
            else self.measures  # type: ignore[assignment]
        )
        return _kld_compute(measures, self.total, self.reduction)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import KLDivergence
            >>> metric = KLDivergence()
            >>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import KLDivergence
            >>> metric = KLDivergence()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/regression/log_cosh.py
@@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.functional.regression.log_cosh import _log_cosh_error_compute, _log_cosh_error_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["LogCoshError.plot"]


class LogCoshError(Metric):
@@ -88,3 +93,44 @@ def update(self, preds: Tensor, target: Tensor) -> None:
    def compute(self) -> Tensor:
        """Compute LogCosh error over state."""
        return _log_cosh_error_compute(self.sum_log_cosh_error, self.total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import LogCoshError
            >>> metric = LogCoshError()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import LogCoshError
            >>> metric = LogCoshError()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)
        """
        return self._plot(val, ax)
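
Across the regression metrics shown above, the new `plot` method has the same contract: it returns a matplotlib figure/axes pair and can optionally draw onto a caller-supplied axis. A minimal end-to-end sketch of that `ax` argument (illustrative only, assuming a torchmetrics build containing this commit and matplotlib installed; the output file name is arbitrary):

import matplotlib.pyplot as plt
from torch import randn

from torchmetrics.regression import ExplainedVariance, LogCoshError

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
ev, lce = ExplainedVariance(), LogCoshError()

# forward() returns the per-batch value and also accumulates metric state
ev_history = [ev(randn(10), randn(10)) for _ in range(5)]
lce_history = [lce(randn(10), randn(10)) for _ in range(5)]

ev.plot(ev_history, ax=ax1)    # draw onto the provided axis instead of creating a new figure
lce.plot(lce_history, ax=ax2)
fig.savefig("regression_metrics.png")
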