Merge branch 'master' into optional_num_classes
Borda authored Jan 8, 2025
2 parents 255db5f + e690bbd commit 03f70ae
Showing 224 changed files with 499 additions and 465 deletions.
23 changes: 19 additions & 4 deletions .github/assistant.py
@@ -132,11 +132,24 @@ def changed_domains(
if not files:
logging.debug("Only integrations was changed so not reason for deep testing...")
return _return_empty

# filter only docs files
files_docs = [fn for fn in files if fn.startswith("docs")]
if len(files) == len(files_docs):
logging.debug("Only docs was changed so not reason for deep testing...")
return _return_empty

files_markdown = [fn for fn in files if fn.endswith(".md")]
if len(files) == len(files_markdown):
logging.debug("Only markdown files was changed so not reason for deep testing...")
return _return_empty

# filter only testing files which are not specific tests so for example configurations or helper tools
files_testing = [fn for fn in files if fn.startswith("tests") and not fn.endswith(".md") and "test_" not in fn]
if files_testing:
logging.debug("Some testing files was changed -> rather test everything...")
return _return_all

# files in requirements folder
files_req = [fn for fn in files if fn.startswith("requirements")]
req_domains = [fn.split("/")[1] for fn in files_req]
@@ -147,19 +160,21 @@ def changed_domains(
return _return_all

# filter only package files and skip inits
_is_in_test = lambda fn: fn.startswith("tests")
_filter_pkg = lambda fn: _is_in_test(fn) or (fn.startswith("src/torchmetrics") and "__init__.py" not in fn)
_is_in_test = lambda fname: fname.startswith("tests")
_filter_pkg = lambda fname: _is_in_test(fname) or (
fname.startswith("src/torchmetrics") and "__init__.py" not in fname
)
files_pkg = [fn for fn in files if _filter_pkg(fn)]
if not files_pkg:
return _return_all

# parse domains
def _crop_path(fname: str, paths: list[str]) -> str:
def _crop_path(fname: str, paths: tuple[str] = ("src/torchmetrics/", "tests/unittests/", "functional/")) -> str:
for p in paths:
fname = fname.replace(p, "")
return fname

files_pkg = [_crop_path(fn, ["src/torchmetrics/", "tests/unittests/", "functional/"]) for fn in files_pkg]
files_pkg = [_crop_path(fn) for fn in files_pkg]
# filter domain names
tm_modules = [fn.split("/")[0] for fn in files_pkg if "/" in fn]
# filter general (used everywhere) sub-packages
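For orientation, the early-return rules added to `changed_domains` above can be exercised on their own. The snippet below is an illustrative, self-contained sketch: the helper name `_classify_changes` and the sample file lists are made up for this example, and the real function returns the `_return_empty` / `_return_all` sentinels rather than strings.

def _classify_changes(files: list[str]) -> str:
    # docs-only change -> skip deep testing
    files_docs = [fn for fn in files if fn.startswith("docs")]
    if len(files) == len(files_docs):
        return "EMPTY"
    # markdown-only change -> skip deep testing
    files_markdown = [fn for fn in files if fn.endswith(".md")]
    if len(files) == len(files_markdown):
        return "EMPTY"
    # shared test tooling (configs, helpers) changed -> test everything
    files_testing = [fn for fn in files if fn.startswith("tests") and not fn.endswith(".md") and "test_" not in fn]
    if files_testing:
        return "ALL"
    return "DOMAINS"  # fall through to the per-domain resolution


print(_classify_changes(["docs/source/index.rst"]))           # EMPTY
print(_classify_changes(["README.md", "CHANGELOG.md"]))       # EMPTY
print(_classify_changes(["tests/unittests/conftest.py"]))     # ALL
print(_classify_changes(["src/torchmetrics/text/rouge.py"]))  # DOMAINS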
6 changes: 3 additions & 3 deletions .github/workflows/ci-integrate.yml
@@ -36,7 +36,7 @@ jobs:
- { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" }
# - { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine # todo: crashing for MPS out of memory
env:
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE: "_ci-cache_PyPI"

@@ -64,7 +64,7 @@ jobs:
# this was updated in `source cashing` by optional oldest
cat requirements/_integrate.txt
# to have install pyTorch
pip install -e . "setuptools==69.5.1" --find-links=${PYTORCH_URL}
pip install -e . "setuptools==69.5.1" --extra-index-url=${PYTORCH_URL}
# adjust version to PT ecosystem based on installed TM
python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py
@@ -73,7 +73,7 @@
# install package and dependencies
pip install -e . -r requirements/_tests.txt -r requirements/_integrate.txt \
--find-links=${PYTORCH_URL} --find-links=${PYPI_CACHE} \
--extra-index-url="${PYTORCH_URL}" --find-links="${PYPI_CACHE}" \
--upgrade-strategy eager
pip list
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
@@ -64,7 +64,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
TOKENIZERS_PARALLELISM: false
TEST_DIRS: ${{ needs.check-diff.outputs.test-dirs }}
PIP_EXTRA_INDEX_URL: "--find-links=https://download.pytorch.org/whl/cpu/torch_stable.html"
PIP_EXTRA_INDEX_URL: "--extra-index-url=https://download.pytorch.org/whl/cpu/"
UNITTEST_TIMEOUT: "" # by default, it is not set

# Timeout: https://stackoverflow.com/a/59076067/4521646
6 changes: 3 additions & 3 deletions .github/workflows/docs-build.yml
@@ -18,7 +18,7 @@ defaults:

env:
FREEZE_REQUIREMENTS: "1"
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL: "https://download.pytorch.org/whl/cpu/"
PYPI_CACHE: "_ci-cache_PyPI"
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: "python"
TOKENIZERS_PARALLELISM: false
@@ -39,7 +39,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.9"
python-version: "3.x"

- name: source cashing
uses: ./.github/actions/pull-caches
@@ -52,7 +52,7 @@ jobs:
make get-sphinx-template
# install with -e so the path to source link comes from this project not from the installed package
pip install -e . -U -r requirements/_docs.txt \
--find-links="${PYPI_CACHE}" --find-links="${TORCH_URL}"
--find-links="${PYPI_CACHE}" --extra-index-url="${TORCH_URL}"
- run: pip list
- name: Full build for deployment
if: github.event_name != 'pull_request'
17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
@@ -23,7 +23,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -46,11 +46,10 @@ repos:
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.22.9
rev: dictgen-v0.3.1
hooks:
- id: typos
# empty to do not write fixes
args: []
args: [] # empty to do not write fixes
exclude: pyproject.toml

- repo: https://github.com/PyCQA/docformatter
@@ -61,12 +60,12 @@ repos:
args: ["--in-place"]

- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v0.9.1
rev: v1.0.0
hooks:
- id: sphinx-lint

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
rev: 0.7.21
hooks:
- id: mdformat
args: ["--number"]
@@ -113,7 +112,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.8.6
hooks:
# try to fix what is possible
- id: ruff
@@ -124,11 +123,11 @@ repos:
- id: ruff

- repo: https://github.com/tox-dev/pyproject-fmt
rev: 2.1.3
rev: v2.5.0
hooks:
- id: pyproject-fmt
additional_dependencies: [tox]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.18
rev: v0.23
hooks:
- id: validate-pyproject
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -842,7 +842,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `PearsonCorrcoef`
* `SpearmanCorrcoef`
- Removed deprecated functions, and warnings in detection and pairwise ([#804](https://github.com/Lightning-AI/metrics/pull/804))
* `MAP` and `functional.pairwise.manhatten`
* `MAP` and `functional.pairwise.manhattan`
- Removed deprecated functions, and warnings in Audio ([#805](https://github.com/Lightning-AI/metrics/pull/805))
* `PESQ` and `functional.audio.pesq`
* `PIT` and `functional.audio.pit`
@@ -1032,7 +1032,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `pairwise_cosine_similarity`
- `pairwise_euclidean_distance`
- `pairwise_linear_similarity`
- `pairwise_manhatten_distance`
- `pairwise_manhattan_distance`

### Changed

1 change: 1 addition & 0 deletions _samples/bert_score-own_model.py
@@ -23,6 +23,7 @@
import torch
from torch import Tensor, nn
from torch.nn import Module

from torchmetrics.text.bert import BERTScore

_NUM_LAYERS = 2
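The blank line added before the `torchmetrics` import above is the same mechanical change applied to most of the files that follow (the `docs/source/pyplots/` and `examples/` modules): third-party imports are grouped separately from first-party `torchmetrics` imports, matching isort-style import sections, presumably as a consequence of the Ruff bump later in this commit. A minimal illustration of the resulting layout (module names are only examples):

# third-party imports form the first group ...
import matplotlib.pyplot as plt
import torch

# ... and first-party torchmetrics imports follow after a blank line
from torchmetrics import Accuracy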
1 change: 1 addition & 0 deletions _samples/detection_map.py
@@ -14,6 +14,7 @@
"""An example of how the predictions and target should be defined for the MAP object detection metric."""

from torch import BoolTensor, IntTensor, Tensor

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Preds should be a list of elements, where each element is a dict
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -18,9 +18,10 @@
from typing import Optional

import lai_sphinx_theme
import torchmetrics
from lightning_utilities.docs.formatting import _linkcode_resolve, _transform_changelog

import torchmetrics

_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
@@ -30,7 +30,7 @@ We provide the remaining interface, such as ``reset()`` that will make sure to c
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself, only in rare
cases where not all the state variables should be reset to their default value. Adding metric states with ``add_state``
will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state()` docs from the base
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state` docs from the base
:class:`~torchmetrics.Metric` class.

Below is a basic implementation of a custom accuracy metric. In the ``__init__`` method we add the metric states
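The `implement.rst` passage above describes registering metric states with `add_state` and then filling in `update` and `compute`; the page goes on to show a basic custom accuracy metric. A rough sketch of that kind of implementation, not the exact code from the docs:

import torch
from torch import Tensor

from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # states added via add_state are reset and DDP-synchronized automatically
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.correct.float() / self.total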
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy_multistep.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary_together.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/confusion_matrix.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/multiclass_accuracy.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/tracker_binary.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions examples/audio/pesq.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio

from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# %%
1 change: 1 addition & 0 deletions examples/audio/signal_to_noise_ratio.py
@@ -13,6 +13,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch

from torchmetrics.audio import SignalNoiseRatio

# %%
1 change: 1 addition & 0 deletions examples/image/clip_score.py
@@ -15,6 +15,7 @@
import torch
from matplotlib.table import Table
from skimage.data import astronaut, cat, coffee

from torchmetrics.multimodal import CLIPScore

# %%
1 change: 1 addition & 0 deletions examples/image/spatial_correlation_coef.py
@@ -16,6 +16,7 @@
import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import iradon, radon, rescale

from torchmetrics.image import SpatialCorrelationCoefficient

# %%
3 changes: 2 additions & 1 deletion examples/text/bertscore.py
@@ -6,9 +6,10 @@
Let's consider a use case in natural language processing where BERTScore is used to evaluate the quality of a text generation model. In this case, imagine we are developing an automated news summarization system. The goal is to create concise summaries of news articles that accurately capture the key points of the original articles. To evaluate the performance of your summarization system, you need a metric that can compare the generated summaries to human-written summaries. This is where BERTScore can be used.
"""

from torchmetrics.text import BERTScore, ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import BERTScore, ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

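The example above pairs a GPT-2 generation pipeline with BERTScore, as its docstring explains. For reference, a minimal, hypothetical use of the metric on its own (the model name is only an example; the first call downloads it):

from torchmetrics.text import BERTScore

preds = ["the cat sat on the mat"]
target = ["a cat was sitting on the mat"]

bertscore = BERTScore(model_name_or_path="distilbert-base-uncased")
scores = bertscore(preds, target)  # dict with "precision", "recall" and "f1" tensors
print(scores)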
3 changes: 2 additions & 1 deletion examples/text/perplexity.py
@@ -12,9 +12,10 @@
# Here's a hypothetical Python example demonstrating the usage of Perplexity to evaluate a generative language model

import torch
from torchmetrics.text import Perplexity
from transformers import AutoModelWithLMHead, AutoTokenizer

from torchmetrics.text import Perplexity

# %%
# Load the GPT-2 model and tokenizer

3 changes: 2 additions & 1 deletion examples/text/rouge.py
@@ -9,9 +9,10 @@
# %%
# Here's a hypothetical Python example demonstrating the usage of unigram ROUGE F-score to evaluate a generative language model:

from torchmetrics.text import ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

2 changes: 1 addition & 1 deletion requirements/image.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
scipy >1.0.0, <1.16.0
torchvision >=0.15.1, <0.22.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
4 changes: 2 additions & 2 deletions requirements/nominal_test.txt
@@ -3,6 +3,6 @@

pandas >1.4.0, <=2.2.3 # cannot pin version due to numpy version incompatibility
dython ==0.7.6 ; python_version <"3.9"
dython ~=0.7.8 ; python_version > "3.8" # we do not use `> =`
scipy >1.0.0, <1.15.0 # cannot pin version due to some version conflicts with `oldest` CI configuration
dython ==0.7.9 ; python_version > "3.8" # we do not use `> =`
scipy >1.0.0, <1.16.0 # cannot pin version due to some version conflicts with `oldest` CI configuration
statsmodels >0.13.5, <0.15.0
2 changes: 1 addition & 1 deletion requirements/segmentation_test.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
scipy >1.0.0, <1.16.0
monai ==1.3.2 ; python_version < "3.9"
monai ==1.4.0 ; python_version > "3.8"
2 changes: 1 addition & 1 deletion requirements/text_test.txt
@@ -5,7 +5,7 @@ jiwer >=2.3.0, <3.1.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.28
sacrebleu >=2.3.0, <2.5.0
sacrebleu >=2.3.0, <2.6.0

mecab-ko >=1.0.0, <1.1.0 ; python_version < "3.12" # strict # todo: unpin python_version
mecab-ko-dic >=1.0.0, <1.1.0 ; python_version < "3.12" # todo: unpin python_version
2 changes: 1 addition & 1 deletion src/conftest.py
@@ -11,7 +11,7 @@
MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED")

@pytest.fixture(autouse=True)
def reset_random_seed(seed: int = 42) -> None: # noqa: PT004
def reset_random_seed(seed: int = 42) -> None:
"""Reset the random seed before running each doctest."""
import random

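The `src/conftest.py` change only drops the `noqa: PT004` marker, presumably because the bumped Ruff no longer applies that rule. For context, an autouse fixture of this shape reseeds the RNGs before every doctest; the body below is an illustrative guess, not the repository's exact code:

import random

import pytest
import torch


@pytest.fixture(autouse=True)
def reset_random_seed(seed: int = 42) -> None:
    """Reset the random seed before running each doctest (illustrative sketch)."""
    random.seed(seed)
    torch.manual_seed(seed)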