MetricWrapper for Target Binarization #2371
Hi! Thanks for your contribution, great first issue!
Hi @lgienapp, thanks for opening this issue. […] How does that sound?
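Establishing a general class sounds good. Just for clarification: the general […] In either case, I would propose an implementation like this (subclassing wrappers here), assuming that only positional params like […] are transformed:

```python
from typing import Any, Tuple, Union

import torch
from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers.abstract import WrapperMetric


class MetricInputTransformer(WrapperMetric):
    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.wrapped_metric = wrapped_metric

    def transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Subclasses define how the positional inputs (e.g. preds, target) are transformed
        raise NotImplementedError

    def update(self, *args: torch.Tensor, **kwargs: Any) -> None:
        self.wrapped_metric.update(*self.transform(*args), **kwargs)

    def compute(self) -> Any:
        return self.wrapped_metric.compute()

    def forward(self, *args: torch.Tensor, **kwargs: Any) -> Any:
        # Return the wrapped metric's value for this batch; keyword args pass through untransformed
        return self.wrapped_metric(*self.transform(*args), **kwargs)
```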
🚀 Feature
Add a `TargetBinarizationWrapper` that casts continuous labels to binary labels given a threshold.
Motivation
Evaluating two metrics that require different label formats (e.g., one binary, the other continuous) is cumbersome since it requires setting up two different evaluation stacks where one is fed with binarized label data and the other is fed the original continuous data. This leads to code duplication. Also, persisting binarized labels into the dataset in scenarios where a metric requires different input than what is given in the ground-truth data diminishes code clarity w.r.t. the evaluation process.
Pitch
A metric wrapper that casts target data to binary targets during the `.update()` and `.forward()` methods. It can be applied to either a single `Metric` or a whole `MetricCollection`.
Alternatives
`MultiTaskWrapper` is possible, but has two caveats: (1) metrics with a signature other than `update(pred, target)` are not supported, and (2) it requires users to implement the thresholding logic themselves before feeding data into the `MultiTaskWrapper`. […] `lambda` that is applied to targets; this is more flexible, but also requires users to implement their own logic. I think binarization is a common enough problem in `torchmetrics` (since its metrics make a binary vs. non-binary distinction) to warrant its own wrapper.
Additional Information
Consider the following example of the desired behaviour:
If simple binarization as in the example is a desired solution, I have all the code needed for a pull request ready and can take on this issue.