
Commit

Merge branch 'newmetric/move_dice' of https://github.com/Lightning-AI/torchmetrics into newmetric/move_dice
SkafteNicki committed Sep 14, 2024
2 parents 9be6172 + 9baf6e6 commit 3950b00
Showing 25 changed files with 238 additions and 96 deletions.
16 changes: 6 additions & 10 deletions .github/workflows/docs-build.yml
@@ -48,18 +48,11 @@ jobs:
pytorch-version: ${{ matrix.pytorch-version }}
pypi-dir: ${{ env.PYPI_CACHE }}

- name: Install Latex
if: ${{ matrix.target == 'html' }}
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
run: |
sudo apt-get update --fix-missing
sudo apt-get install -y \
texlive-latex-extra texlive-pictures texlive-fonts-recommended dvipng cm-super
- name: Install package & dependencies
run: |
make get-sphinx-template
pip install . -U -r requirements/_docs.txt \
# install with -e so the path to source link comes from this project not from the installed package
pip install -e . -U -r requirements/_docs.txt \
--find-links="${PYPI_CACHE}" --find-links="${TORCH_URL}"
- run: pip list
- name: Full build for deployment
@@ -70,7 +63,10 @@ jobs:
run: echo "SPHINX_ENABLE_GALLERY=0" >> $GITHUB_ENV
- name: make ${{ matrix.target }}
working-directory: ./docs
run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
run: |
pwd
ls -la
make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
- name: Upload built docs
if: ${{ matrix.target == 'html' && github.event_name != 'pull_request' }}
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -38,16 +38,33 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Deprecated

- Deprecated `num_outputs` in `R2Score` ([#2705](https://github.com/Lightning-AI/torchmetrics/pull/2705))


### Removed

-


### Fixed

- Fixed wrong aggregation in `segmentation.MeanIoU` ([#2698](https://github.com/Lightning-AI/torchmetrics/pull/2698))


- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726))


- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733))


- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))


## [1.4.1] - 2024-08-02

### Changed
8 changes: 6 additions & 2 deletions docs/source/conf.py
@@ -269,6 +269,11 @@ def _set_root_image_path(page_path: str) -> None:
),
]

# MathJax configuration
mathjax3_config = {
"tex": {"packages": {"[+]": ["ams", "newcommand", "configMacros"]}},
}

# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
@@ -358,8 +363,7 @@ def package_list_from_file(file: str) -> list[str]:
autodoc_mock_imports = MOCK_PACKAGES


# Resolve function
# This function is used to populate the (source) links in the API
# Resolve function - this function is used to populate the (source) links in the API
def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001
return _linkcode_resolve(domain, info=info, github_user="Lightning-AI", github_repo="torchmetrics")

2 changes: 1 addition & 1 deletion docs/source/links.rst
@@ -91,7 +91,7 @@
.. _CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model
.. _MER: https://www.isca-speech.org/archive/interspeech_2004/morris04_interspeech.html
.. _WIL: https://www.isca-speech.org/archive/interspeech_2004/morris04_interspeech.html
.. _WIP: https://infoscience.epfl.ch/entities/publication/9983d013-8239-422e-a3f7-a1500d309474
.. _WIP: https://www.isca-archive.org/interspeech_2004/morris04_interspeech.pdf
.. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising
.. _InfoLM: https://arxiv.org/abs/2112.01589
.. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
2 changes: 1 addition & 1 deletion requirements/classification_test.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

pandas >=1.4.0, <=2.2.2
pandas >1.4.0, <=2.2.2
netcal >1.0.0, <1.4.0 # calibration_error
numpy <2.2.0
fairlearn # group_fairness
2 changes: 1 addition & 1 deletion requirements/detection_test.txt
@@ -1,4 +1,4 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

faster-coco-eval >=1.3.3
faster-coco-eval ==1.5.*
4 changes: 2 additions & 2 deletions requirements/nominal_test.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

pandas >1.0.0, <=2.2.2 # cannot pin version due to numpy version incompatibility
dython <=0.7.7
pandas >1.4.0, <=2.2.2 # cannot pin version due to numpy version incompatibility
dython ~=0.7.6
scipy >1.0.0, <1.15.0 # cannot pin version due to some version conflicts with `oldest` CI configuration
statsmodels >0.13.5, <0.15.0
2 changes: 1 addition & 1 deletion requirements/typing.txt
@@ -1,5 +1,5 @@
mypy ==1.11.2
torch ==2.4.0
torch ==2.4.1

types-PyYAML
types-emoji
2 changes: 1 addition & 1 deletion requirements/visual.txt
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

matplotlib >=3.3.0, <3.10.0
matplotlib >=3.6.0, <3.10.0
SciencePlots >= 2.0.0, <2.2.0
7 changes: 7 additions & 0 deletions src/torchmetrics/__init__.py
@@ -20,6 +20,13 @@
if not hasattr(PIL, "PILLOW_VERSION"):
PIL.PILLOW_VERSION = PIL.__version__

if package_available("scipy"):
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import ( # noqa: E402
CatMetric,
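For context on the guarded patch above: recent SciPy releases dropped the top-level scipy.signal.hamming alias in favour of scipy.signal.windows.hamming, and the assignment restores the old name for downstream audio dependencies that still reference it. A minimal standalone sketch of the same idea follows; the final `window` line is illustrative only and not part of the commit.

import scipy.signal

# newer SciPy keeps window functions under scipy.signal.windows; alias the old name
if not hasattr(scipy.signal, "hamming"):
    scipy.signal.hamming = scipy.signal.windows.hamming

window = scipy.signal.hamming(8)  # now resolves on both old and new SciPy versions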
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/dice.py
@@ -116,9 +116,9 @@ class Dice(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
.. warning::
The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
domain in v1.6.0 with slight modifications to functionality.
Raises:
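Since the deprecation note points to f1score as the replacement, a hedged migration sketch follows. It assumes the common case in which the old Dice defaults (micro averaging over a multiclass task) are reproduced by a micro-averaged MulticlassF1Score; the tensors are invented purely for illustration.

import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 2])

# micro-averaged F1 plays the role of the deprecated Dice metric with default settings
f1 = MulticlassF1Score(num_classes=3, average="micro")
print(f1(preds, target))  # tensor(0.5000)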
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/classification/dice.py
@@ -152,9 +152,9 @@ def dice(
than what they appear to be.
.. warning::
The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
domain in v1.6.0 with slight modifications to functionality.
Return:
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/jaccard.py
@@ -67,7 +67,7 @@ def _jaccard_index_reduce(
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
confmat = confmat.float()
if average == "binary":
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])
return _safe_divide(confmat[1, 1], (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]), zero_division=zero_division)

ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0]
multilabel = confmat.ndim == 3
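The change above routes the binary case through torchmetrics' internal _safe_divide so that an empty confusion matrix (no positives in either predictions or targets) no longer yields NaN. The helper below is a simplified stand-in written only to illustrate that behaviour; it is not the function from the commit.

import torch
from torch import Tensor


def safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
    # replace a zero denominator with 1, then overwrite the result with `zero_division`
    zero_mask = denom == 0
    out = num / torch.where(zero_mask, torch.ones_like(denom), denom)
    return torch.where(zero_mask, torch.full_like(out, zero_division), out)


confmat = torch.zeros(2, 2)  # a batch with no positive predictions or targets
iou = safe_divide(confmat[1, 1], confmat[0, 1] + confmat[1, 0] + confmat[1, 1])
print(iou)  # tensor(0.) rather than NaN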
1 change: 1 addition & 0 deletions src/torchmetrics/functional/nominal/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa
from torchmetrics.functional.nominal.pearson import (
12 changes: 6 additions & 6 deletions src/torchmetrics/metric.py
@@ -284,7 +284,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch
statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding
statistics to the overall accumulating metric state. Input arguments are the exact same as corresponding
``update`` method. The returned output is the exact same as the output of ``compute``.
Args:
@@ -361,7 +361,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""Forward computation using single call to `update`.
This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for
This can be done when the global metric state is a simple reduction of batch states. This can be unsafe for
certain metric cases but is also the fastest way to both accumulate globally and compute locally.
"""
@@ -802,7 +802,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
"""Overwrite `_apply` function such that we can also move metric states to the correct device.
This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
Args:
fn: the function to apply
@@ -1166,15 +1166,15 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
"""

def update(self, *args: Any, **kwargs: Any) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

if isinstance(self.metric_b, Metric):
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

def compute(self) -> Any:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
# also some parsing for kwargs?
val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a
val_b = self.metric_b.compute() if isinstance(self.metric_b, Metric) else self.metric_b
@@ -1216,7 +1216,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return self._forward_cache

def reset(self) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.reset()

1 change: 1 addition & 0 deletions src/torchmetrics/nominal/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.nominal.cramers import CramersV
from torchmetrics.nominal.fleiss_kappa import FleissKappa
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient
47 changes: 31 additions & 16 deletions src/torchmetrics/regression/r2.py
@@ -13,11 +13,11 @@
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

@@ -65,23 +65,32 @@ class R2Score(Metric):
* ``'variance_weighted'`` scores are weighted by their individual variances
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
.. warning::
Argument ``num_outputs`` in ``R2Score`` has been deprecated because it is no longer necessary and will be
removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape
of the input tensors.
Raises:
ValueError:
If ``adjusted`` parameter is not an integer larger or equal to 0.
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example:
Example (single output):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
Example (multioutput):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
@@ -100,14 +109,20 @@ class R2Score(Metric):

def __init__(
self,
num_outputs: int = 1,
num_outputs: Optional[int] = None,
adjusted: int = 0,
multioutput: str = "uniform_average",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.num_outputs = num_outputs
if num_outputs is not None:
rank_zero_warn(
"Argument `num_outputs` in `R2Score` has been deprecated because it is no longer necessary and will be"
"removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape"
"of the input tensors.",
DeprecationWarning,
)

if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.")
@@ -120,19 +135,19 @@ def __init__(
)
self.multioutput = multioutput

self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("residual", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total
self.sum_squared_error = self.sum_squared_error + sum_squared_error
self.sum_error = self.sum_error + sum_error
self.residual = self.residual + residual
self.total = self.total + total

def compute(self) -> Tensor:
"""Compute r2 score over the metric states."""
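A note on the state changes above: with num_outputs deprecated, the states now default to scalar zeros and update reassigns them instead of using in-place +=, so ordinary broadcasting grows each state to whatever output shape the first batch provides. The toy sketch below mimics that pattern; the names and values are illustrative, not the metric's real internals.

import torch

state = torch.tensor(0.0)                 # scalar default; number of outputs unknown
batch_stat = torch.tensor([1.5, 2.0])     # first batch turns out to have two outputs

state = state + batch_stat                # broadcasting grows the state to shape (2,)
state = state + torch.tensor([0.5, 0.5])  # later batches accumulate element-wise
print(state)                              # tensor([2.0000, 2.5000])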
8 changes: 5 additions & 3 deletions src/torchmetrics/segmentation/mean_iou.py
@@ -110,7 +110,8 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
@@ -119,10 +120,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()
self.num_batches += 1

def compute(self) -> Tensor:
"""Update the state with the new data."""
return self.score # / self.num_batches
"""Compute the final Mean Intersection over Union (mIoU)."""
return self.score / self.num_batches

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
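The fix above swaps the "mean" distributed reduction for a "sum" plus an explicit batch counter, so compute now returns the average of the per-batch scores instead of an ever-growing running sum. A small sketch of the corrected bookkeeping, with made-up scores:

import torch

batch_scores = [torch.tensor(0.50), torch.tensor(0.70), torch.tensor(0.60)]

score_sum = torch.tensor(0.0)
num_batches = torch.tensor(0)
for s in batch_scores:              # what repeated update() calls accumulate
    score_sum += s
    num_batches += 1

print(score_sum / num_batches)      # what compute() now returns: tensor(0.6000)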
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
@@ -64,6 +64,7 @@
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
_IPADIC_AVAILABLE = RequirementCache("ipadic")
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SCIPI_AVAILABLE = RequirementCache("scipy")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
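The new _SCIPI_AVAILABLE flag follows the same RequirementCache pattern as its neighbours. A short sketch of how such a flag is typically consumed is given below; the lightning_utilities import path and the fallback branch are assumptions for illustration, not code from this commit.

from lightning_utilities.core.imports import RequirementCache

_SCIPI_AVAILABLE = RequirementCache("scipy")

if _SCIPI_AVAILABLE:
    import scipy.signal  # optional dependency is safe to import here
else:
    scipy = None  # callers should raise a helpful error before relying on scipy features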