diff --git a/models/common.py b/models/common.py index d59ee976c9f7..5119881e683f 100644 --- a/models/common.py +++ b/models/common.py @@ -531,7 +531,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images t = [time_sync()] - p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type + p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference if isinstance(imgs, torch.Tensor): # torch with amp.autocast(autocast):