
Metrics not on same device as Trainer/model #231

Closed

mandeep-starling opened this issue Aug 29, 2024 · 2 comments

@mandeep-starling

When running on a CUDA device, the following code:

from pytorch_widedeep import Trainer
from torchmetrics.classification import BinaryAUROC, BinaryPrecision, BinaryRecall

# model is a WideDeep model that the Trainer will run on the CUDA device
trainer = Trainer(model, objective="binary", metrics=[BinaryAUROC(), BinaryPrecision(), BinaryRecall()])

raises the following error when fit is called:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
...
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/utils/general_utils.py", line 12, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 524, in fit
    train_score, train_loss = self._train_step(
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 1000, in _train_step
    score = self._get_score(y_pred, y)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 1049, in _get_score
    score = self.metric(torch.sigmoid(y_pred), y)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/metrics.py", line 40, in __call__
    metric.update(y_pred, y_true.int())  # type: ignore[attr-defined]
  File "/opt/conda/lib/python3.10/site-packages/torchmetrics/metric.py", line 486, in wrapped_func
    raise RuntimeError(
RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=BinaryPrecision(...)` try to do `metric=BinaryPrecision(...).to(device)` where device corresponds to the device of the input.

I could work around this manually on my end, by getting the device from here and setting it on each metric. But I wonder if this should work out of the box?
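
For reference, a minimal sketch of that workaround; the device selection here is illustrative and assumes the Trainer trains on CUDA when it is available, which may not match the Trainer's actual logic:

import torch
from pytorch_widedeep import Trainer
from torchmetrics.classification import BinaryAUROC, BinaryPrecision, BinaryRecall

# Illustrative device selection; assumes the Trainer will use the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the stateful metric objects to that device before handing them over
# (model is the same WideDeep model as in the snippet above)
metrics = [m.to(device) for m in (BinaryAUROC(), BinaryPrecision(), BinaryRecall())]
trainer = Trainer(model, objective="binary", metrics=metrics)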

I read some discussion in pytorch-lightning here and here where it seems there is a way for metrics to be automatically moved to the same device as the model?
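
If I understand those discussions correctly, the mechanism they rely on is that class-based torchmetrics metrics subclass nn.Module, so a metric registered as a child module of the model moves together with it. A minimal sketch (the module and attribute names here are just illustrative):

import torch.nn as nn
from torchmetrics.classification import BinaryAUROC

class ModelWithMetric(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 1)
        # BinaryAUROC is an nn.Module, so registering it as an attribute
        # means .to()/.cuda() on the parent also moves its state tensors
        self.auroc = BinaryAUROC()

model = ModelWithMetric()
model.to("cuda")  # moves both the linear layer and the metric state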

@jrzaurin
Owner

@5uperpalo I am traveling. If you have a sec, maybe you could have a look?

Otherwise I'll have a look in the coming days.

@jrzaurin
Owner

Hey @mandeep-starling

There are a number of ways we could tackle this.

One is, as you suggest, getting the device from the Trainer and passing it to the metrics, like:

trainer = Trainer(model, objective="binary", metrics=[BinaryAUROC().to(device), ...])

I can also set it up internally, but that involves some complications depending on whether a user decides to use the nn.Module-based torchmetrics metrics or their functional versions.
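
For illustration, a minimal sketch of what the internal handling could look like; the helper name is hypothetical, not existing pytorch_widedeep API, and it assumes the Trainer knows its device:

import torch
from torchmetrics import Metric

def _move_metrics_to_device(metrics, device):
    # Hypothetical helper: class-based torchmetrics metrics are nn.Modules
    # with internal state tensors, so they need an explicit .to(device);
    # functional metrics are stateless callables and are left untouched.
    return [m.to(device) if isinstance(m, Metric) else m for m in metrics]

Assuming the Trainer stores its device as self.device, calling something like _move_metrics_to_device(metrics, self.device) during setup would make the original snippet work unchanged.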

I will think about it. I might create a branch and you could try it and see what you prefer.
