Expected all tensors to be on the same device #341
Question

I get the error `Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!` with the following code:

```python
from torchmetrics.classification import F1

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self.forward(x)
    loss = self.loss_fn(logits, y)
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    f1 = F1(num_classes=self.args.num_classes)(preds, y)
    self.log('val_loss', loss, prog_bar=True)
    self.log('val_acc', acc, prog_bar=True)
    self.log('val_f1', f1, prog_bar=True)
    return loss
```
```python
model = ImageNetClassify(args)
earlystop = EarlyStopping("val_acc", patience=5)
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',          # metric to monitor
    dirpath='./ntt/alphamind',  # directory for saved weights
    mode='max',                 # save when the monitored metric is at its maximum
    verbose=True,
)
trainer = Trainer(logger=tb_logger,
                  # weights_summary='full',
                  progress_bar_refresh_rate=1,
                  gpus=1,
                  auto_select_gpus=True,
                  log_gpu_memory='all',
                  benchmark=True,
                  # max_epochs=30,
                  num_sanity_val_steps=2,
                  auto_scale_batch_size=True,
                  auto_lr_find=True,
                  callbacks=[earlystop, checkpoint_callback])
```
Answered by SkafteNicki, Jul 1, 2021
Hi @jaffe-fly,

Modular metrics are `nn.Module`s and therefore need to be moved to the same device as the input. This is done automatically if you define them in the `__init__` method of your model. You can read more here: https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices

However, a perhaps easier fix in your case is to use the functional version of the `F1` metric:

```python
from torchmetrics.functional import f1

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self.forward(x)
    loss = self.loss_fn(logits, y)
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    # use a distinct name: assigning to `f1` would shadow the imported
    # function and raise UnboundLocalError on the call
    f1_score = f1(preds, y, num_classes=self.args.num_classes)
    self.log('val_loss', loss, prog_bar=True)
    self.log('val_acc', acc, prog_bar=True)
    self.log('val_f1', f1_score, prog_bar=True)
    return loss
```

Answer selected by Borda