Skip to content

Commit

Permalink
Move CSI metric to right domain (#2309)
Browse files Browse the repository at this point in the history
* move csi to right domain
* fix docstrings
* move docs

(cherry picked from commit 548a51f)
  • Loading branch information
SkafteNicki authored and Borda committed Jan 17, 2024
1 parent 42ee556 commit f0e11da
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
.. customcarditem::
:header: Critical Success Index (CSI)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################
Critical Success Index (CSI)
############################

Module Interface
________________

.. autoclass:: torchmetrics.image.CriticalSuccessIndex
:exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.critical_success_index
.. customcarditem::
:header: Critical Success Index (CSI)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Regression

.. include:: ../links.rst

############################
Critical Success Index (CSI)
############################

Module Interface
________________

.. autoclass:: torchmetrics.regression.CriticalSuccessIndex
:exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.regression.critical_success_index
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
CriticalSuccessIndex,
ExplainedVariance,
KendallRankCorrCoef,
KLDivergence,
Expand Down Expand Up @@ -164,6 +165,7 @@
"ConfusionMatrix",
"CosineSimilarity",
"CramersV",
"CriticalSuccessIndex",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
from torchmetrics.functional.regression import (
concordance_corrcoef,
cosine_similarity,
critical_success_index,
explained_variance,
kendall_rank_corrcoef,
kl_divergence,
Expand Down Expand Up @@ -150,6 +151,7 @@
"cosine_similarity",
"cramers_v",
"cramers_v_matrix",
"critical_success_index",
"tweedie_deviance_score",
"dice",
"error_relative_global_dimensionless_synthesis",
Expand Down
2 changes: 0 additions & 2 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.d_s import spatial_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
Expand Down Expand Up @@ -49,6 +48,5 @@
"visual_information_fidelity",
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
"critical_success_index",
"spatial_correlation_coefficient",
]
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.regression.concordance import concordance_corrcoef
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity
from torchmetrics.functional.regression.csi import critical_success_index
from torchmetrics.functional.regression.explained_variance import explained_variance
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
from torchmetrics.functional.regression.kl_divergence import kl_divergence
Expand All @@ -34,6 +34,7 @@
__all__ = [
"concordance_corrcoef",
"cosine_similarity",
"critical_success_index",
"explained_variance",
"kendall_rank_corrcoef",
"kl_divergence",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ def critical_success_index(
Example:
>>> import torch
>>> from torchmetrics.functional.image.csi import critical_success_index
>>> from torchmetrics.functional.regression import critical_success_index
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> critical_success_index(x, y, 0.5)
tensor(0.3333)
Example:
>>> import torch
>>> from torchmetrics.functional.image.csi import critical_success_index
>>> from torchmetrics.functional.regression import critical_success_index
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> critical_success_index(x, y, 0.5, keep_sequence_dim=0)
Expand Down
2 changes: 0 additions & 2 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.csi import CriticalSuccessIndex
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.d_s import SpatialDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
Expand Down Expand Up @@ -46,7 +45,6 @@
"UniversalImageQualityIndex",
"VisualInformationFidelity",
"TotalVariation",
"CriticalSuccessIndex",
"SpatialCorrelationCoefficient",
]

Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.regression.concordance import ConcordanceCorrCoef
from torchmetrics.regression.cosine_similarity import CosineSimilarity
from torchmetrics.regression.csi import CriticalSuccessIndex
from torchmetrics.regression.explained_variance import ExplainedVariance
from torchmetrics.regression.kendall import KendallRankCorrCoef
from torchmetrics.regression.kl_divergence import KLDivergence
Expand All @@ -33,6 +34,7 @@
__all__ = [
"ConcordanceCorrCoef",
"CosineSimilarity",
"CriticalSuccessIndex",
"ExplainedVariance",
"KendallRankCorrCoef",
"KLDivergence",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from torchmetrics.functional.image.csi import _critical_success_index_compute, _critical_success_index_update
from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import dim_zero_cat

Expand All @@ -39,7 +39,7 @@ class CriticalSuccessIndex(Metric):
Example:
>>> import torch
>>> from torchmetrics.image.csi import CriticalSuccessIndex
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> csi = CriticalSuccessIndex(0.5)
Expand All @@ -48,7 +48,7 @@ class CriticalSuccessIndex(Metric):
Example:
>>> import torch
>>> from torchmetrics.image.csi import CriticalSuccessIndex
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/image/test_csi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import pytest
import torch
from sklearn.metrics import jaccard_score
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.image.csi import CriticalSuccessIndex
from torchmetrics.functional.regression.csi import critical_success_index
from torchmetrics.regression.csi import CriticalSuccessIndex

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
Expand Down

0 comments on commit f0e11da

Please sign in to comment.