Merge branch 'master' into ter
Borda authored Dec 6, 2021
2 parents 27ef8f9 + 8d5b8ba commit 0d4828e
Showing 23 changed files with 692 additions and 400 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs-check.yml
@@ -82,5 +82,5 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: docs-build
path: docs/build/html/
if: success()
path: docs/build/
if: always()
19 changes: 14 additions & 5 deletions CHANGELOG.md
@@ -18,9 +18,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))
- `TER` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))

- Add a default VS Code devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621))

- Added Signal to Distortion Ratio (`SDR`) to `audio` package ([#565](https://github.com/PyTorchLightning/metrics/pull/565))
- Add a default VSCode devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621))


- Added Signal to Distortion Ratio (`SDR`) to audio package ([#565](https://github.com/PyTorchLightning/metrics/pull/565))


- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))
@@ -31,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))


- Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup ([#627](https://github.com/PyTorchLightning/metrics/pull/627))
- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))


### Deprecated
@@ -45,12 +47,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594), [#624](https://github.com/PyTorchLightning/metrics/pull/624))


- Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
## [0.6.1] - 2021-12-06

### Changed

- Migrate MAP metrics from pycocotools to PyTorch ([#632](https://github.com/PyTorchLightning/metrics/pull/632))
- Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup ([#627](https://github.com/PyTorchLightning/metrics/pull/627))

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594), [#610](https://github.com/PyTorchLightning/metrics/pull/610), [#624](https://github.com/PyTorchLightning/metrics/pull/624))
- Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
- Fixed `forward` in compositional metrics ([#645](https://github.com/PyTorchLightning/metrics/pull/645))


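One of the entries above, "Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup" (#627), boils down to the fact that precision@k only needs the k highest-scored items, so a full sort is wasted work. A minimal sketch of the idea in plain PyTorch; the function name `retrieval_precision_topk` is hypothetical and this is not the torchmetrics implementation:

```python
import torch

def retrieval_precision_topk(preds: torch.Tensor, target: torch.Tensor, k: int) -> torch.Tensor:
    """Precision@k: fraction of relevant items among the k highest-scored ones (sketch only)."""
    # torch.topk returns just the k best indices, avoiding a full sort of `preds`
    top_idx = torch.topk(preds, k=min(k, preds.numel())).indices
    return target[top_idx].float().mean()

preds = torch.tensor([0.2, 0.9, 0.4, 0.7])
target = torch.tensor([0, 1, 0, 1])  # 1 marks a relevant document
print(retrieval_precision_topk(preds, target, k=2))  # tensor(1.) -- both top-2 items are relevant
```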
2 changes: 2 additions & 0 deletions Makefile
@@ -13,6 +13,8 @@ clean:
rm -rf ./docs/source/generated
rm -rf ./docs/source/*/generated
rm -rf ./docs/source/api
rm -rf build
rm -rf dist

test: clean env

10 changes: 5 additions & 5 deletions tests/bases/test_composition.py
@@ -310,10 +310,10 @@ def test_metrics_or(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
pytest.param(DummyMetric(2), tensor(4)),
pytest.param(2, tensor(4)),
(DummyMetric(2), tensor(4)),
(2, tensor(4)),
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
pytest.param(tensor(2), tensor(4)),
(tensor(2), tensor(4)),
],
)
def test_metrics_pow(second_operand, expected_result):
@@ -376,8 +376,8 @@ def test_metrics_rmod(first_operand, expected_result):
@pytest.mark.parametrize(
"first_operand,expected_result",
[
pytest.param(DummyMetric(2), tensor(4)),
pytest.param(2, tensor(4)),
(DummyMetric(2), tensor(4)),
(2, tensor(4)),
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
],
)
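The change in this file (and the analogous ones below) relies on `pytest.mark.parametrize` accepting plain tuples for ordinary cases; `pytest.param(...)` is only needed when a case carries extra metadata such as `marks` or an `id`. A small stand-alone illustration with a hypothetical test, not taken from this repository:

```python
import pytest
from torch import tensor

@pytest.mark.parametrize(
    ["second_operand", "expected_result"],
    [
        (2, tensor(4)),  # a plain tuple is enough for an ordinary case
        pytest.param(2.0, tensor(4.0), marks=pytest.mark.skip(reason="example of an attached mark")),
    ],
)
def test_square(second_operand, expected_result):
    assert tensor(second_operand) ** 2 == expected_result
```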
10 changes: 5 additions & 5 deletions tests/classification/test_auc.py
@@ -85,11 +85,11 @@ def test_auc_differentiability(self, x, y, reorder):
@pytest.mark.parametrize(
["x", "y", "expected"],
[
pytest.param([0, 1], [0, 1], 0.5),
pytest.param([1, 0], [0, 1], 0.5),
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
pytest.param([0, 1], [1, 1], 1),
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
([0, 1], [0, 1], 0.5),
([1, 0], [0, 1], 0.5),
([1, 0, 0], [0, 1, 1], 0.5),
([0, 1], [1, 1], 1),
([0, 0.5, 1], [0, 0.5, 1], 0.5),
],
)
def test_auc(x, y, expected, unsqueeze_x, unsqueeze_y):
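The expected values in `test_auc` follow from trapezoidal integration, which can be reproduced with plain PyTorch (a sketch, not the torchmetrics implementation):

```python
import torch

x = torch.tensor([0.0, 1.0])
y = torch.tensor([0.0, 1.0])
print(torch.trapz(y, x))  # tensor(0.5000) -- matches the first parametrized case

# Unsorted x, as in the ([1, 0, 0], [0, 1, 1]) case, has to be reordered first
x = torch.tensor([1.0, 0.0, 0.0])
y = torch.tensor([0.0, 1.0, 1.0])
order = torch.argsort(x)
print(torch.trapz(y[order], x[order]))  # tensor(0.5000)
```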
4 changes: 2 additions & 2 deletions tests/classification/test_average_precision.py
@@ -134,9 +134,9 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25),
(tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25),
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75),
(tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75),
],
)
def test_average_precision(scores, target, expected_score):
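The second expected score above (0.75) can be reproduced with the step-wise definition AP = sum_k (R_k - R_{k-1}) * P_k over the descending-score ranking. A self-contained check, not the torchmetrics code:

```python
import torch

def average_precision(scores: torch.Tensor, target: torch.Tensor) -> float:
    """AP = sum over ranks of (recall increase) * precision at that rank (sketch only)."""
    order = torch.argsort(scores, descending=True)
    target = target[order].float()
    tp = torch.cumsum(target, dim=0)                    # true positives at each cut-off
    precision = tp / torch.arange(1, len(target) + 1)
    recall = tp / target.sum()
    prev_recall = torch.cat([torch.tensor([0.0]), recall[:-1]])
    return float(((recall - prev_recall) * precision).sum())

# Reproduces the second case; the constant-score case (expected 0.25) additionally
# requires grouping predictions by unique threshold, which this sketch skips.
print(average_precision(torch.tensor([0.6, 0.7, 0.8, 9.0]), torch.tensor([1, 0, 0, 1])))  # 0.75
```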
8 changes: 4 additions & 4 deletions tests/classification/test_dice.py
@@ -20,10 +20,10 @@
@pytest.mark.parametrize(
["pred", "target", "expected"],
[
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0),
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0),
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0),
([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0),
([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0),
([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0),
],
)
def test_dice_score(pred, target, expected):
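The expected values follow directly from the Dice formula, 2 * |pred ∩ target| / (|pred| + |target|) over positive entries; the third case is 2 * 2 / (4 + 2) = 2/3. A minimal check in plain PyTorch (not the torchmetrics implementation):

```python
import torch

def dice(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """2 * intersection / (pred positives + target positives), for binary masks (sketch only)."""
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum())

print(dice(torch.tensor([[1, 1], [1, 1]]), torch.tensor([[1, 1], [0, 0]])))  # tensor(0.6667)
```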
56 changes: 28 additions & 28 deletions tests/classification/test_iou.py
@@ -139,12 +139,12 @@ def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_cl
@pytest.mark.parametrize(
["half_ones", "reduction", "ignore_index", "expected"],
[
pytest.param(False, "none", None, Tensor([1, 1, 1])),
pytest.param(False, "elementwise_mean", None, Tensor([1])),
pytest.param(False, "none", 0, Tensor([1, 1])),
pytest.param(True, "none", None, Tensor([0.5, 0.5, 0.5])),
pytest.param(True, "elementwise_mean", None, Tensor([0.5])),
pytest.param(True, "none", 0, Tensor([2 / 3, 1 / 2])),
(False, "none", None, Tensor([1, 1, 1])),
(False, "elementwise_mean", None, Tensor([1])),
(False, "none", 0, Tensor([1, 1])),
(True, "none", None, Tensor([0.5, 0.5, 0.5])),
(True, "elementwise_mean", None, Tensor([0.5])),
(True, "none", 0, Tensor([2 / 3, 1 / 2])),
],
)
def test_iou(half_ones, reduction, ignore_index, expected):
@@ -168,30 +168,30 @@ def test_iou(half_ones, reduction, ignore_index, expected):
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1.0, 2, [1.0, -1.0]),
pytest.param([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]),
([0], [0], None, -1.0, 2, [1.0, -1.0]),
([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1.0, 1, [1.0]),
([0], [0], None, -1.0, 1, [1.0]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1.0, 2, [-1.0, 1.0]),
pytest.param([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]),
([1], [1], None, -1.0, 2, [-1.0, 1.0]),
([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1.0, 2, [1.0]),
([1], [1], 0, -1.0, 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]),
pytest.param([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]),
([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]),
([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]),
pytest.param([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]),
([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]),
([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]),
([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]),
([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]),
([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]),
([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]),
],
)
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
@@ -212,16 +212,16 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
["pred", "target", "ignore_index", "num_classes", "reduction", "expected"],
[
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
],
)
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
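The parametrizations above encode the intended semantics of `absent_score` (returned for a class that appears in neither `pred` nor `target`) and `ignore_index` (the class is dropped from the result before any reduction). A rough per-class IoU sketch with those two knobs, assuming flat integer label tensors; it is not the torchmetrics implementation:

```python
import torch

def per_class_iou(pred, target, num_classes, ignore_index=None, absent_score=0.0):
    scores = []
    for c in range(num_classes):
        if c == ignore_index:
            continue  # the ignored class contributes nothing, not even absent_score
        p, t = pred == c, target == c
        union = (p | t).sum()
        if union == 0:
            scores.append(torch.tensor(absent_score))  # class absent from both pred and target
        else:
            scores.append((p & t).sum() / union)
    return torch.stack(scores)

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])
print(per_class_iou(pred, target, num_classes=3))                  # tensor([1.0000, 0.5000, 0.6667])
print(per_class_iou(pred, target, num_classes=3, ignore_index=0))  # tensor([0.5000, 0.6667])
```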
2 changes: 1 addition & 1 deletion tests/classification/test_precision_recall_curve.py
@@ -110,7 +110,7 @@ def test_precision_recall_curve_differentiability(self, preds, target, sk_metric

@pytest.mark.parametrize(
["pred", "target", "expected_p", "expected_r", "expected_t"],
[pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])],
[([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])],
)
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(tensor(pred), tensor(target))
10 changes: 5 additions & 5 deletions tests/classification/test_roc.py
@@ -130,11 +130,11 @@ def test_roc_differentiability(self, preds, target, sk_metric, num_classes):
@pytest.mark.parametrize(
["pred", "target", "expected_tpr", "expected_fpr"],
[
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
([1, 1], [1, 0], [0, 1], [0, 1]),
([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
],
)
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
41 changes: 14 additions & 27 deletions tests/detection/test_map.py
@@ -19,11 +19,7 @@

from tests.helpers.testers import MetricTester
from torchmetrics.detection.map import MAP
from torchmetrics.utilities.imports import (
_PYCOCOTOOLS_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCHVISION_GREATER_EQUAL_0_8,
)
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8

Input = namedtuple("Input", ["preds", "target"])

@@ -59,7 +55,7 @@
), # coco image id 74
dict(
boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
scores=torch.Tensor([0.699, 0.423]),
scores=torch.Tensor([0.699]),
labels=torch.IntTensor([5]),
), # coco image id 133
],
@@ -164,10 +160,10 @@ def _compare_fn(preds, target) -> dict:
}


_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)
_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
class TestMAP(MetricTester):
"""Test the MAP metric for object detection predictions.
@@ -194,7 +190,7 @@ def test_map(self, ddp):


# noinspection PyTypeChecker
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
MAP() # no error
@@ -203,20 +199,11 @@ def test_error_on_wrong_init():
MAP(class_metrics=0)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""
metric = MAP()

metric.update(
[
dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)

metric.update(
[
dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
@@ -235,17 +222,17 @@ def test_empty_metric():
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
metric = MAP()

metric.update([], []) # no error

with pytest.raises(ValueError, match="Expected argument `preds` to be of type List"):
with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"):
metric.update(torch.Tensor(), []) # type: ignore

with pytest.raises(ValueError, match="Expected argument `target` to be of type List"):
with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"):
metric.update([], torch.Tensor()) # type: ignore

with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
@@ -281,31 +268,31 @@ def test_error_on_wrong_input():
[dict(boxes=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=[], scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=[], labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=[])],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=[], labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all labels in `target` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=[])],
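For reference, the input format the validation above enforces: `preds` and `target` are sequences of per-image dictionaries whose values are tensors, with one detection dict (boxes, scores, labels) and one ground-truth dict (boxes, labels) per image. A hedged usage sketch based on the fixtures in this file, assuming the default `xyxy` box format; the box coordinates are illustrative only:

```python
import torch
from torchmetrics.detection.map import MAP  # import path used by this test file

metric = MAP()
metric.update(
    [  # predictions: one dict per image
        dict(
            boxes=torch.Tensor([[258.15, 41.29, 606.41, 285.07]]),
            scores=torch.Tensor([0.536]),
            labels=torch.IntTensor([0]),
        )
    ],
    [  # ground truth: one dict per image
        dict(
            boxes=torch.Tensor([[214.15, 41.29, 562.41, 285.07]]),
            labels=torch.IntTensor([0]),
        )
    ],
)
result = metric.compute()  # a dict of mAP/mAR tensors, e.g. result["map"]
```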
18 changes: 9 additions & 9 deletions tests/image/test_ssim.py
@@ -107,15 +107,15 @@ def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size):
@pytest.mark.parametrize(
["pred", "target", "kernel", "sigma"],
[
pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
],
)
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
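The invalid-input cases above reflect the constraints SSIM places on its arguments: a 4D (N, C, H, W) input, two entries each for `kernel_size` and `sigma`, odd positive kernel sizes, and positive sigmas. A hedged sketch of such validation; `check_ssim_args` is a hypothetical helper, not the torchmetrics code:

```python
from typing import Sequence

def check_ssim_args(shape: Sequence[int], kernel: Sequence[int], sigma: Sequence[float]) -> None:
    if len(shape) != 4:
        raise ValueError("Expected a 4D input of shape (N, C, H, W)")
    if len(kernel) != 2 or len(sigma) != 2:
        raise ValueError("kernel_size and sigma must each have two entries for 2D images")
    if any(k <= 0 or k % 2 == 0 for k in kernel):
        raise ValueError("kernel_size entries must be odd positive integers")
    if any(s <= 0 for s in sigma):
        raise ValueError("sigma entries must be positive")

check_ssim_args([1, 1, 16, 16], [11, 11], [1.5, 1.5])    # passes
# check_ssim_args([1, 1, 16, 16], [11, 10], [1.5, 1.5])  # would raise: even kernel size
```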