Skip to content

Commit

Permalink
Newmetric: ProcrustesDisparity (#2723)
Browse files Browse the repository at this point in the history
* some docs

* some code

* docs

* nearly working code

* improve tests

* fix src

* changelog

* fix bare except

* fix doctest

* fix mypy

* Update src/torchmetrics/functional/shape/procrustes.py

* Update src/torchmetrics/functional/shape/procrustes.py

* Update src/torchmetrics/functional/shape/procrustes.py

* Update src/torchmetrics/shape/procrustes.py

* Apply suggestions from code review

* rename input variables

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 11, 2024
1 parent c2c68b6 commit a15ef9a
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605))


- Added new metric `ProcrustesDistance` to new domain Shape ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723)


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ Or directly from conda

segmentation/*

.. toctree::
:maxdepth: 2
:name: shape
:caption: Shape
:glob:

shape/*

.. toctree::
:maxdepth: 2
:name: text
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,4 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
2 changes: 1 addition & 1 deletion docs/source/segmentation/mean_iou.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.. customcarditem::
:header: Mean Intersection over Union (mIoU)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg
:tags: segmentation
:tags: Segmentation

###################################
Mean Intersection over Union (mIoU)
Expand Down
22 changes: 22 additions & 0 deletions docs/source/shape/procrustes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Procrustes Disparity
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: shape

.. include:: ../links.rst

####################
Procrustes Disparity
####################

Module Interface
________________

.. autoclass:: torchmetrics.shape.ProcrustesDisparity
:exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.shape.procrustes_disparity
16 changes: 16 additions & 0 deletions src/torchmetrics/functional/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.shape.procrustes import procrustes_disparity

__all__ = ["procrustes_disparity"]
66 changes: 66 additions & 0 deletions src/torchmetrics/functional/shape/procrustes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import torch
from torch import Tensor, linalg

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.prints import rank_zero_warn


def procrustes_disparity(
point_cloud1: Tensor, point_cloud2: Tensor, return_all: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
"""Runs procrustrus analysis on a batch of data points.
Works similar ``scipy.spatial.procrustes`` but for batches of data points.
Args:
point_cloud1: The first set of data points
point_cloud2: The second set of data points
return_all: If True, returns the scale and rotation matrices along with the disparity
"""
_check_same_shape(point_cloud1, point_cloud2)
if point_cloud1.ndim != 3:
raise ValueError(
"Expected both datasets to be 3D tensors of shape (N, M, D), where N is the batch size, M is the number of"
f" data points and D is the dimensionality of the data points, but got {point_cloud1.ndim} dimensions."
)

point_cloud1 = point_cloud1 - point_cloud1.mean(dim=1, keepdim=True)
point_cloud2 = point_cloud2 - point_cloud2.mean(dim=1, keepdim=True)
point_cloud1 /= linalg.norm(point_cloud1, dim=[1, 2], keepdim=True)
point_cloud2 /= linalg.norm(point_cloud2, dim=[1, 2], keepdim=True)

try:
u, w, v = linalg.svd(
torch.matmul(point_cloud2.transpose(1, 2), point_cloud1).transpose(1, 2), full_matrices=False
)
except Exception as ex:
rank_zero_warn(
f"SVD calculation in procrustes_disparity failed with exception {ex}. Returning 0 disparity and identity"
" scale/rotation.",
UserWarning,
)
return torch.tensor(0.0), torch.ones(point_cloud1.shape[0]), torch.eye(point_cloud1.shape[2])

rotation = torch.matmul(u, v)
scale = w.sum(1, keepdim=True)
point_cloud2 = scale[:, None] * torch.matmul(point_cloud2, rotation.transpose(1, 2))
disparity = (point_cloud1 - point_cloud2).square().sum(dim=[1, 2])
if return_all:
return disparity, scale, rotation
return disparity
16 changes: 16 additions & 0 deletions src/torchmetrics/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.shape.procrustes import ProcrustesDisparity

__all__ = ["ProcrustesDisparity"]
137 changes: 137 additions & 0 deletions src/torchmetrics/shape/procrustes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics import Metric
from torchmetrics.functional.shape.procrustes import procrustes_disparity
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ProcrustesDisparity.plot"]


class ProcrustesDisparity(Metric):
r"""Compute the `Procrustes Disparity`_.
The Procrustes Disparity is defined as the sum of the squared differences between two datasets after
applying a Procrustes transformation. The Procrustes Disparity is useful to compare two datasets
that are similar but not aligned.
The metric works similar to ``scipy.spatial.procrustes`` but for batches of data points. The disparity is
aggregated over the batch, thus to get the individual disparities please use the functional version of this
metric: ``torchmetrics.functional.shape.procrustes.procrustes_disparity``.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``point_cloud1`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size,
``M`` the number of data points and ``D`` the dimensionality of the data points.
- ``point_cloud2`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size,
``M`` the number of data points and ``D`` the dimensionality of the data points.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``gds`` (:class:`~torch.Tensor`): A scalar tensor with the Procrustes Disparity.
Args:
reduction: Determines whether to return the mean disparity or the sum of the disparities.
Can be one of ``"mean"`` or ``"sum"``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError: If ``average`` is not one of ``"mean"`` or ``"sum"``.
Example:
>>> from torch import randn
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> point_cloud1 = randn(10, 50, 2)
>>> point_cloud2 = randn(10, 50, 2)
>>> metric(point_cloud1, point_cloud2)
tensor(0.9770)
"""

disparity: Tensor
total: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(self, reduction: Literal["mean", "sum"] = "mean", **kwargs: Any) -> None:
super().__init__(**kwargs)
if reduction not in ("mean", "sum"):
raise ValueError(f"Argument `reduction` must be one of ['mean', 'sum'], got {reduction}")
self.reduction = reduction
self.add_state("disparity", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, point_cloud1: torch.Tensor, point_cloud2: torch.Tensor) -> None:
"""Update the Procrustes Disparity with the given datasets."""
disparity: Tensor = procrustes_disparity(point_cloud1, point_cloud2) # type: ignore[assignment]
self.disparity += disparity.sum()
self.total += disparity.numel()

def compute(self) -> torch.Tensor:
"""Computes the Procrustes Disparity."""
if self.reduction == "mean":
return self.disparity / self.total
return self.disparity

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> metric.update(torch.randn(10, 50, 2), torch.randn(10, 50, 2))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randn(10, 50, 2), torch.randn(10, 50, 2)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
13 changes: 13 additions & 0 deletions tests/unittests/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit a15ef9a

Please sign in to comment.