diff --git a/hubconf.py b/hubconf.py index 01f4eba08c81..7e69ac7f0211 100644 --- a/hubconf.py +++ b/hubconf.py @@ -41,7 +41,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo name = Path(name) path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path try: - device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device) + device = select_device(('0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available( + ) else 'cpu') if device is None else device) if pretrained and channels == 3 and classes == 80: model = DetectMultiBackend(path, device=device) # download/load FP32 model