Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Catch URLError (#237)
Browse files Browse the repository at this point in the history
* Catch URLError

* Updates

* Update CHANGELOG.md

* Update CHANGELOG.md

* Fix error
  • Loading branch information
ethanwharris authored Apr 22, 2021
1 parent 1f9e151 commit c28a22e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))

- Fixed a bug where loading from a local checkpoint that had `pretrained=True` without an internet connection would sometimes raise an error ([#237](https://github.com/PyTorchLightning/lightning-flash/pull/237))

### Removed


Expand Down
36 changes: 29 additions & 7 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import urllib.error
import warnings
from functools import partial
from typing import Tuple

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

Expand Down Expand Up @@ -51,6 +52,24 @@
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")


def catch_url_error(fn):

@functools.wraps(fn)
def wrapper(pretrained=False, **kwargs):
try:
return fn(pretrained=pretrained, **kwargs)
except urllib.error.URLError:
result = fn(pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
" ignored.", UserWarning
)
return result

return wrapper


@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts")
def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_):
simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False)
Expand Down Expand Up @@ -83,7 +102,7 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
_type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg"

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_mobilenet_vgg, model_name),
fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -99,7 +118,7 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_resnet, model_name),
fn=catch_url_error(partial(_fn_resnet, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -118,7 +137,10 @@ def _fn_resnet_fpn(
return backbone, 256

OBJ_DETECTION_BACKBONES(
fn=partial(_fn_resnet_fpn, model_name), name=model_name, package="torchvision", type="resnet-fpn"
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
package="torchvision",
type="resnet-fpn"
)

for model_name in DENSENET_MODELS:
Expand All @@ -130,7 +152,7 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_densenet, model_name),
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
Expand All @@ -156,5 +178,5 @@ def _fn_timm(
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=partial(_fn_timm, model_name), name=model_name, namespace="vision", package="timm"
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)
4 changes: 2 additions & 2 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def get_model(
)
else:
backbone_model, num_features = ObjectDetector.backbones.get(backbone)(
pretrained_backbone,
trainable_backbone_layers,
pretrained=pretrained_backbone,
trainable_layers=trainable_backbone_layers,
**kwargs,
)
backbone_model.out_channels = num_features
Expand Down
14 changes: 13 additions & 1 deletion tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import urllib.error

import pytest
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.utils.imports import _TIMM_AVAILABLE
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.vision.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES


@pytest.mark.parametrize(["backbone", "expected_num_features"], [
Expand All @@ -17,3 +19,13 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features):
backbone_model, num_features = backbone_fn(pretrained=False)
assert backbone_model
assert num_features == expected_num_features


def test_pretrained_backbones_catch_url_error():

def raise_error_if_pretrained(pretrained=False):
if pretrained:
raise urllib.error.URLError('Test error')

with pytest.warns(UserWarning, match="Failed to download pretrained weights"):
catch_url_error(raise_error_if_pretrained)(pretrained=True)

0 comments on commit c28a22e

Please sign in to comment.