From 9512accc5054c9d0fd15b5a422b466c83b96e108 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 10:59:35 +0200 Subject: [PATCH] rev --- torchmetrics/metric.py | 2 +- torchmetrics/retrieval/retrieval_metric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index c691a3a5b0a..7151643be55 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -335,7 +335,7 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any: return wrapped_func @abstractmethod - def update(self, *_: Any) -> None: + def update(self, *_: Any, **__: Any) -> None: """ Override this method to update the state variables of your metric class. """ diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 1782708c34f..1ee9e5f656c 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -90,7 +90,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore """ Check shape, check and convert dtypes, flatten and add to accumulators. """ if indexes is None: raise ValueError("Argument `indexes` cannot be None")