From 210bb706bff93830dabfdddd3321e418891d6bdc Mon Sep 17 00:00:00 2001 From: QuentinCouton <62642291+QuentinCouton@users.noreply.github.com> Date: Tue, 14 Jun 2022 12:29:21 +0200 Subject: [PATCH 1/2] Update hubconf.py --- hubconf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 01f4eba08c81..a1b27508c735 100644 --- a/hubconf.py +++ b/hubconf.py @@ -41,7 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo name = Path(name) path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path try: - device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device) + device = select_device(('0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') if device is None else device) if pretrained and channels == 3 and classes == 80: model = DetectMultiBackend(path, device=device) # download/load FP32 model From c07cb64df465b121103f7c45f3b66427e6104dbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Jun 2022 10:30:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- hubconf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index a1b27508c735..7e69ac7f0211 100644 --- a/hubconf.py +++ b/hubconf.py @@ -41,7 +41,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo name = Path(name) path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path try: - device = select_device(('0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') if device is None else device) + device = select_device(('0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available( + ) else 'cpu') if device is None else device) if pretrained and channels == 3 and classes == 80: model = DetectMultiBackend(path, device=device) # download/load FP32 model