Commit

Merge branch 'master' into cpu_topk_performance
SkafteNicki authored Aug 3, 2021
2 parents 18d142d + 21fe0ca commit e03878a
Showing 133 changed files with 553 additions and 804 deletions.
4 changes: 2 additions & 2 deletions .github/set-minimal-versions.py
@@ -12,7 +12,7 @@


def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:
"""set minimal torch version"""
"""set minimal torch version."""
py_ver = f"{sys.version_info.major}.{sys.version_info.minor}"
if py_ver not in LUT_PYTHON_TORCH:
return
@@ -24,7 +24,7 @@ def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:


def replace_min_requirements(fpath: str) -> None:
"""replace all `>=` by `==` in given file"""
"""replace all `>=` by `==` in given file."""
logging.info(f"processing: {fpath}")
with open(fpath) as fp:
req = fp.read()
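The rest of `replace_min_requirements` is collapsed in this view. As a hedged sketch, the `>=` to `==` rewrite the docstring promises can be as simple as the following; only the first lines appear in the diff above, and the replacement and write-back below are assumptions for illustration, not the repo's actual code:

    import logging


    def replace_min_requirements(fpath: str) -> None:
        """Replace all `>=` by `==` in given file."""
        logging.info(f"processing: {fpath}")
        with open(fpath) as fp:
            req = fp.read()
        # Pin every lower bound to an exact version, e.g. `torch>=1.3.1` -> `torch==1.3.1`.
        req = req.replace(">=", "==")
        with open(fpath, "w") as fp:
            fp.write(req)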
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -43,6 +43,12 @@ repos:
args: [--py36-plus]
name: Upgrade code

- repo: https://github.com/myint/docformatter
rev: v1.4
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

- repo: https://github.com/PyCQA/isort
rev: 5.9.2
hooks:
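This new hook accounts for most of the docstring churn in the rest of this commit. A small Python illustration of the rewrite `docformatter` performs with these arguments, shown on a function from this same diff (behaviour summarized from the hunks, not an exhaustive description of the tool):

    # Before the hook runs:
    def set_min_torch_by_python():
        """set minimal torch version"""


    # After `docformatter --in-place --wrap-summaries=115 --wrap-descriptions=120`:
    # a trailing period is added, and summary/description lines are re-wrapped to
    # 115/120 columns, collapsing short multi-line summaries onto one line.
    def set_min_torch_by_python():
        """set minimal torch version."""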
40 changes: 40 additions & 0 deletions docs/source/links.rst
@@ -0,0 +1,40 @@

.. _scikit-learn's implementation of SMAPE: https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/metrics/_regression.py#L197
.. _scikit-learn's implementation of MAPE: https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/metrics/_regression.py#L197
.. _Mean Average Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision
.. _Fall-out: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out
.. _Normalized Discounted Cumulative Gain: https://en.wikipedia.org/wiki/Discounted_cumulative_gain
.. _IR Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision
.. _IR Recall: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall
.. _Accuracy: https://en.wikipedia.org/wiki/Accuracy_and_precision
.. _SMAPE: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
.. _SNR: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
.. _ROC AUC: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations
.. _Cohen's kappa score: https://en.wikipedia.org/wiki/Cohen%27s_kappa
.. _scikit-learn's implementation of confusion matrix: https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix
.. _confusion matrix gets calculated per label: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html
.. _F-score: https://en.wikipedia.org/wiki/F-score
.. _Hamming distance: https://en.wikipedia.org/wiki/Hamming_distance
.. _Hinge loss: https://en.wikipedia.org/wiki/Hinge_loss
.. _Jaccard index: https://en.wikipedia.org/wiki/Jaccard_index
.. _KL divergence: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
.. _Matthews correlation coefficient: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
.. _Precision: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Recall: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Specificity: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
.. _Type I and Type II errors: https://en.wikipedia.org/wiki/Type_I_and_type_II_errors
.. _confusion matrix: https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
.. _sklearn averaging methods: https://scikit-learn.org/stable/modules/model_evaluation.html#multiclass-and-multilabel-classification
.. _Cosine Similarity: https://en.wikipedia.org/wiki/Cosine_similarity
.. _coefficient of determination: https://en.wikipedia.org/wiki/Coefficient_of_determination
.. _spearmans rank correlation coefficient: https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient
.. _WER: https://en.wikipedia.org/wiki/Word_error_rate
.. _FID: https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance
.. _mean-squared-error: https://en.wikipedia.org/wiki/Mean_squared_error
.. _SSIM: https://en.wikipedia.org/wiki/Structural_similarity
.. _explained variance: https://en.wikipedia.org/wiki/Explained_variation
.. _IR Average precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
.. _IR Fall-out: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out
.. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
.. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
.. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
2 changes: 2 additions & 0 deletions docs/source/references/functional.rst
@@ -1,6 +1,8 @@
.. role:: hidden
:class: hidden-section

.. include:: ../links.rst

##################
Functional metrics
##################
Expand Down
2 changes: 2 additions & 0 deletions docs/source/references/modules.rst
@@ -2,6 +2,8 @@
Module metrics
##############

.. include:: ../links.rst

**********
Base class
**********
4 changes: 1 addition & 3 deletions integrations/lightning/boring_model.py
@@ -43,8 +43,7 @@ def __len__(self):

class BoringModel(LightningModule):
def __init__(self):
"""
Testing PL Module
"""Testing PL Module.
Use as follows:
- subclass
@@ -58,7 +57,6 @@ def training_step(...):
model = BaseTestModel()
model.training_epoch_end = None
"""
super().__init__()
self.layer = torch.nn.Linear(32, 2)
1 change: 1 addition & 0 deletions integrations/test_lightning.py
@@ -85,6 +85,7 @@ def training_epoch_end(self, outs):
@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher")
def test_metrics_reset(tmpdir):
"""Tests that metrics are reset correctly after the end of the train/val/test epoch.
Taken from:
https://github.com/PyTorchLightning/pytorch-lightning/pull/7055
"""
4 changes: 2 additions & 2 deletions tests/audio/test_pit.py
@@ -51,7 +51,7 @@ def naive_implementation_pit_scipy(
metric_func: Callable,
eval_func: str,
) -> Tuple[Tensor, Tensor]:
"""A naive implementation of `Permutation Invariant Training` based on Scipy
"""A naive implementation of `Permutation Invariant Training` based on Scipy.
Args:
preds: predictions, shape[batch, spk, time]
@@ -83,7 +83,7 @@ def naive_implementation_pit_scipy(


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
"""average the metric values
"""average the metric values.
Args:
preds: predictions, shape[batch, spk, time]
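The docstring above describes the scipy-based reference. A minimal self-contained sketch of such a permutation search, assuming `metric_func` maps two `[time]` tensors to a scalar; the name and signature are illustrative, not the test's actual code:

    import torch
    from scipy.optimize import linear_sum_assignment


    def naive_pit(preds, target, metric_func, eval_func: str = "max"):
        """Score every (pred speaker, target speaker) pair and keep the best assignment."""
        batch, spk, _ = preds.shape
        best = []
        for b in range(batch):
            # Pairwise metric between each predicted and each reference speaker.
            cost = torch.empty(spk, spk)
            for i in range(spk):
                for j in range(spk):
                    cost[i, j] = metric_func(preds[b, i], target[b, j])
            # linear_sum_assignment minimizes, so negate when the metric is maximized.
            sign = -1 if eval_func == "max" else 1
            rows, cols = linear_sum_assignment(sign * cost.numpy())
            best.append(cost[torch.as_tensor(rows), torch.as_tensor(cols)].mean())
        return torch.stack(best)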
13 changes: 5 additions & 8 deletions tests/bases/test_collections.py
@@ -84,7 +84,7 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):


def test_metric_collection_wrong_input(tmpdir):
"""Check that errors are raised on wrong input"""
"""Check that errors are raised on wrong input."""
dms = DummyMetricSum()

# Not all input are metrics (list)
@@ -105,9 +105,8 @@ def test_metric_collection_wrong_input(tmpdir):


def test_metric_collection_args_kwargs(tmpdir):
"""Check that args and kwargs gets passed correctly in metric collection,
Checks both update and forward method
"""
"""Check that args and kwargs gets passed correctly in metric collection, Checks both update and forward
method."""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()

@@ -143,7 +142,7 @@ def test_metric_collection_args_kwargs(tmpdir):
],
)
def test_metric_collection_prefix_postfix_args(prefix, postfix):
"""Test that the prefix arg alters the keywords in the output"""
"""Test that the prefix arg alters the keywords in the output."""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
names = ["DummyMetricSum", "DummyMetricDiff"]
@@ -192,9 +191,7 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix):


def test_metric_collection_repr():
"""
Test MetricCollection
"""
"""Test MetricCollection."""

class A(DummyMetricSum):
pass
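For context, the prefix/postfix behaviour exercised by these tests looks like this in ordinary use; the key shown in the comment is indicative, and `Accuracy`'s default arguments are assumed:

    import torch
    from torchmetrics import Accuracy, MetricCollection

    metrics = MetricCollection([Accuracy()], prefix="train_")
    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])
    out = metrics(preds, target)  # one forward call updates every member metric
    # roughly: {"train_Accuracy": tensor(0.7500)}; `prefix` decorates each result key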
8 changes: 3 additions & 5 deletions tests/bases/test_ddp.py
@@ -117,7 +117,7 @@ def compute(self):

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_non_contiguous_tensors():
"""Test that gather_all operation works for non contiguous tensors"""
"""Test that gather_all operation works for non contiguous tensors."""
torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)


@@ -225,8 +225,6 @@ def reload_state_dict(state_dict, expected_x, expected_c):

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_state_dict_is_synced(tmpdir):
"""
This test asserts that metrics are synced while creating the state
dict but restored after to continue accumulation.
"""
"""This test asserts that metrics are synced while creating the state dict but restored after to continue
accumulation."""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)
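The non-contiguous case guarded by `test_non_contiguous_tensors` is easy to reproduce; a quick standalone snippet:

    import torch

    t = torch.arange(6).reshape(2, 3).t()  # a transposed view is not contiguous
    assert not t.is_contiguous()
    # Collective ops such as all_gather generally require contiguous memory,
    # so a gather utility must call .contiguous() on inputs like this one first.
    assert t.contiguous().is_contiguous()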
12 changes: 6 additions & 6 deletions tests/bases/test_metric.py
@@ -210,7 +210,7 @@ def test_pickle(tmpdir):


def test_state_dict(tmpdir):
"""test that metric states can be removed and added to state dict"""
"""test that metric states can be removed and added to state dict."""
metric = DummyMetric()
assert metric.state_dict() == OrderedDict()
metric.persistent(True)
@@ -220,7 +220,7 @@ def test_state_dict(tmpdir):


def test_load_state_dict(tmpdir):
"""test that metric states can be loaded with state dict"""
"""test that metric states can be loaded with state dict."""
metric = DummyMetricSum()
metric.persistent(True)
metric.update(5)
@@ -230,7 +230,7 @@ def test_load_state_dict(tmpdir):


def test_child_metric_state_dict():
"""test that child metric states will be added to parent state dict"""
"""test that child metric states will be added to parent state dict."""

class TestModule(nn.Module):
def __init__(self):
@@ -270,7 +270,7 @@ def test_device_and_dtype_transfer(tmpdir):


def test_warning_on_compute_before_update():
"""test that an warning is raised if user tries to call compute before update"""
"""test that an warning is raised if user tries to call compute before update."""
metric = DummyMetricSum()

# make sure everything is fine with forward
@@ -293,13 +293,13 @@ def test_warning_on_compute_before_update():


def test_metric_scripts():
"""test that metrics are scriptable"""
"""test that metrics are scriptable."""
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())


def test_metric_forward_cache_reset():
"""test that forward cache is reset when `reset` is called"""
"""test that forward cache is reset when `reset` is called."""
metric = DummyMetricSum()
_ = metric(2.0)
assert metric._forward_cache == 2.0
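A minimal metric mirroring what `test_state_dict` and `test_load_state_dict` exercise, assuming the public `Metric` API (`add_state`, `persistent`) at this version; `SumMetric` is a stand-in for the tests' `DummyMetricSum`, not the actual fixture:

    import torch
    from torchmetrics import Metric


    class SumMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, value):
            self.total += value

        def compute(self):
            return self.total


    metric = SumMetric()
    metric.update(5)
    assert metric.state_dict() == {}  # metric states are non-persistent by default
    metric.persistent(True)
    assert "total" in metric.state_dict()  # now the state survives checkpointing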
5 changes: 2 additions & 3 deletions tests/classification/test_accuracy.py
@@ -327,9 +327,8 @@ def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, m
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is not present
in both the `preds` and `target`, the resulting score is `nan`.
"""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2
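Written out, the scenario `test_class_not_present` parametrizes above is the following; the `average="none"` spelling is assumed from this version's functional API:

    import torch
    from torchmetrics.functional import accuracy

    preds = torch.tensor([0, 0, 0])
    target = torch.tensor([0, 0, 0])
    # Class 1 appears in neither `preds` nor `target`, so its per-class score is NaN.
    result = accuracy(preds, target, average="none", num_classes=2)
    # expected: tensor([1., nan]), matching the ignore_index=None case above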
10 changes: 5 additions & 5 deletions tests/classification/test_auroc.py
@@ -168,9 +168,8 @@ def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, av


def test_error_on_different_mode():
"""test that an error is raised if the user pass in data of
different modes (binary, multi-label, multi-class)
"""
"""test that an error is raised if the user pass in data of different modes (binary, multi-label, multi-
class)"""
metric = AUROC()
# pass in multi-class data
metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,)))
@@ -187,8 +186,9 @@ def test_error_multiclass_no_num_classes():


def test_weighted_with_empty_classes():
"""Tests that weighted multiclass AUROC calculation yields the same results if a new
but empty class exists. Tests that the proper warnings and errors are raised
"""Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists.
Tests that the proper warnings and errors are raised
"""
preds = torch.tensor(
[
11 changes: 5 additions & 6 deletions tests/classification/test_f_beta.py
@@ -155,7 +155,8 @@ def test_zero_division(metric_class, metric_fn):
],
)
def test_no_support(metric_class, metric_fn):
"""This tests a rare edge case, where there is only one class present
"""This tests a rare edge case, where there is only one class present.
in target, and ignore_index is set to exactly that class - and the
average method is equal to 'weighted'.
@@ -182,9 +183,8 @@ def test_no_support(metric_class, metric_fn):
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is not present
in both the `preds` and `target`, the resulting score is `nan`.
"""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2
@@ -412,8 +412,7 @@ def test_top_k(
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already guarantee
the correctness of results.
Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)
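The top_k semantics sanity-checked here (and the operation the `cpu_topk_performance` branch this merge targets is optimizing) reduce to a plain `torch.topk` membership test:

    import torch

    probs = torch.tensor([[0.1, 0.6, 0.3],
                          [0.5, 0.2, 0.3]])
    target = torch.tensor([2, 0])
    # With top_k=2 a sample counts as correct if the target class is among
    # the two highest-scoring predictions.
    top2 = probs.topk(k=2, dim=-1).indices
    correct = (top2 == target.unsqueeze(-1)).any(dim=-1)  # tensor([True, True])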
12 changes: 6 additions & 6 deletions tests/classification/test_precision_recall.py
@@ -149,7 +149,8 @@ def test_zero_division(metric_class, metric_fn):

@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_no_support(metric_class, metric_fn):
"""This tests a rare edge case, where there is only one class present
"""This tests a rare edge case, where there is only one class present.
in target, and ignore_index is set to exactly that class - and the
average method is equal to 'weighted'.
@@ -355,8 +356,8 @@ def test_precision_recall_differentiability(
def test_precision_recall_joint(average):
"""A simple test of the joint precision_recall metric.
No need to test this thoroughly, as it is just a combination of precision and recall,
which are already tested thoroughly.
No need to test this thoroughly, as it is just a combination of precision and recall, which are already tested
thoroughly.
"""

precision_result = precision(
@@ -422,9 +423,8 @@ def test_top_k(
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is not present
in both the `preds` and `target`, the resulting score is `nan`.
"""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2
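As the docstring says, the joint metric is just the two base metrics computed on the same inputs; a hedged sketch of the equivalence being asserted, with default arguments assumed:

    import torch
    from torchmetrics.functional import precision, recall

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])
    # The joint precision_recall result should match these two separate calls.
    p = precision(preds, target)
    r = recall(preds, target)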
2 changes: 1 addition & 1 deletion tests/classification/test_precision_recall_curve.py
@@ -31,7 +31,7 @@


def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1):
"""Adjusted comparison function that can also handles multiclass"""
"""Adjusted comparison function that can also handles multiclass."""
if num_classes == 1:
return sk_precision_recall_curve(y_true, probas_pred)

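One plausible shape for the adjusted helper: loop one-vs-rest over classes and call scikit-learn per class. A sketch under that assumption; the test's real implementation may differ:

    import numpy as np
    from sklearn.metrics import precision_recall_curve


    def multiclass_prc(y_true: np.ndarray, probas: np.ndarray, num_classes: int):
        """Return one (precision, recall, thresholds) triple per class, one-vs-rest."""
        return [precision_recall_curve(y_true == c, probas[:, c]) for c in range(num_classes)]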
2 changes: 1 addition & 1 deletion tests/classification/test_roc.py
@@ -33,7 +33,7 @@


def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
"""Adjusted comparison function that can also handles multiclass"""
"""Adjusted comparison function that can also handles multiclass."""
if num_classes == 1:
return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)

Expand Down
8 changes: 4 additions & 4 deletions tests/classification/test_specificity.py
@@ -176,7 +176,8 @@ def test_zero_division(metric_class, metric_fn):

@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
def test_no_support(metric_class, metric_fn):
"""This tests a rare edge case, where there is only one class present
"""This tests a rare edge case, where there is only one class present.
in target, and ignore_index is set to exactly that class - and the
average method is equal to 'weighted'.
@@ -396,9 +397,8 @@ def test_top_k(
"ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is not present
in both the `preds` and `target`, the resulting score is `nan`.
"""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2