Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: info retrieval #332

Merged
merged 9 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330))
- Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330),
[#332](https://github.com/PyTorchLightning/metrics/pull/332))


### Deprecated
Expand Down
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,3 @@ ignore_errors = True
# todo: add proper typing to this module...
[mypy-torchmetrics.regression.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.retrieval.*]
ignore_errors = True
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

Expand All @@ -22,7 +24,7 @@ def _dcg(target: Tensor) -> Tensor:
return (target / denom).sum()


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes Normalized Discounted Cumulative Gain (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`__.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_precision(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes the precision metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Precision>`__.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes the recall metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`__.
Expand Down
16 changes: 8 additions & 8 deletions torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,18 @@ def __init__(

empty_target_action_options = ('error', 'skip', 'neg', 'pos')
if empty_target_action not in empty_target_action_options:
raise ValueError(f"`empty_target_action` received a wrong value `{empty_target_action}`.")
raise ValueError(f"Argument `empty_target_action` received a wrong value `{empty_target_action}`.")

self.empty_target_action = empty_target_action

self.add_state("indexes", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor, indexes: Tensor = None) -> None:
def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maximsch2 any suggestion how to deal with this

error: Signature of "update" incompatible with supertype "Metric"  [override]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda I approved since everything else looks fine but let's not have this PR merged before this is resolved (this is probably the main issue of metrics typing right now :) )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets merge it and address this specific case in separate PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock @SkafteNicki @Borda : Since the metric base class uses *args, **kwargs this is currently unavaoidable. Is there a suppression we can add in this case?

Alternatively, why we need to specify *args, **kwargs in the metric base class? This is entirely controlled by the concrete implementation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

""" Check shape, check and convert dtypes, flatten and add to accumulators. """
if indexes is None:
raise ValueError("`indexes` cannot be None")
raise ValueError("Argument `indexes` cannot be None")

indexes, preds, target = _check_retrieval_inputs(indexes, preds, target)

Expand All @@ -103,10 +103,10 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor = None) -> None:

def compute(self) -> Tensor:
"""
First concat state `indexes`, `preds` and `target` since they were stored as lists. After that,
compute list of groups that will help in keeping together predictions about the same query.
Finally, for each group compute the `_metric` if the number of positive targets is at least
1, otherwise behave as specified by `self.empty_target_action`.
First concat state ``indexes``, ``preds`` and ``target`` since they were stored as lists.
After that, compute list of groups that will help in keeping together predictions about the same query.
Finally, for each group compute the ``_metric`` if the number of positive targets is at least
1, otherwise behave as specified by ``self.empty_target_action``.
"""
indexes = torch.cat(self.indexes, dim=0)
preds = torch.cat(self.preds, dim=0)
Expand All @@ -127,7 +127,7 @@ def compute(self) -> Tensor:
elif self.empty_target_action == 'neg':
res.append(tensor(0.0))
else:
# ensure list containt only float tensors
# ensure list contains only float tensors
res.append(self._metric(mini_preds, mini_target))

return torch.stack([x.to(preds) for x in res]).mean() if res else tensor(0.0).to(preds)
Expand Down