false device inconsistent runtime error #602
It seems to be a common confusion; please see #531.
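For reference, here is a minimal sketch of the usual fix in that kind of confusion (assuming, as #531 presumably describes, that the metric's state tensors live on a different device than the inputs; this example is not from the thread itself):

```python
import torch
from torchmetrics import Accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metric states are tensors, so the metric must be moved like any module;
# feeding GPU tensors to a CPU-resident metric raises a device error.
metric = Accuracy().to(device)

preds = torch.randint(0, 2, (8,), device=device)
target = torch.randint(0, 2, (8,), device=device)
print(metric(preds, target))
```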
Thanks for the reply. I am using a PL module, and I actually think this is related to:

```python
from pytorch_lightning import LightningModule
from torch.nn import ModuleDict
from torchmetrics import AUROC, Accuracy, F1, MetricCollection, Precision, Recall


class FineTuneModule(LightningModule):
    def __init__(
        self,
        arch: str,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Metrics: one collection cloned per phase, registered in a ModuleDict
        # so Lightning moves them to the right device along with the module.
        mc = MetricCollection({
            "accuracy": Accuracy(threshold=0.0),
            "recall": Recall(threshold=0.0, num_classes=7, average='macro'),
            "precision": Precision(threshold=0.0, num_classes=7, average='macro'),
            "f1": F1(threshold=0.0, num_classes=7, average='macro'),
            "macro_auc": AUROC(num_classes=7, average='macro'),
            "weighted_auc": AUROC(num_classes=7, average='weighted'),
        })
        self.metrics: ModuleDict = ModuleDict({
            f"{phase}_metric": mc.clone()
            for phase in ["train", "valid", "test"]
        })
```

and in the step:

```python
# phase is "train", "valid", or "test"
metrics = self.metrics[f"{phase}_metric"]
metrics(output['prob'], output['label'])
```

And I think that is where the error comes from. Sorry I didn't make this clear.
Reopening as I can confirm there seems to be a bug in the code here.
🐛 Bug
When using weighted AUC, a device-inconsistency RuntimeError is raised even if the input tensors and the metric are on the same device.
To Reproduce
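The reporter's snippet did not survive in this copy of the issue; the following is a minimal sketch of the setup described, assuming torchmetrics 0.6.0 and an available CUDA device:

```python
import torch
from torchmetrics import AUROC

device = torch.device("cuda")

# Metric and inputs are explicitly placed on the same device.
auroc = AUROC(num_classes=7, average="weighted").to(device)

preds = torch.randn(16, 7, device=device).softmax(dim=-1)
target = torch.arange(16, device=device) % 7  # all 7 classes present

auroc(preds, target)  # reportedly raises despite matching devices
```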
Running a snippet like this would result in the device-inconsistency runtime error from the title.
Based on this error message, I assume this is because some CPU tensors are created in an intermediate step? But I am not too sure.
Environment
- Install method (conda, pip, source): pip
- Torchmetrics version: 0.6.0
- PyTorch Lightning version: 1.4.9