Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
updated PR
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Jul 12, 2021
1 parent 8565a49 commit 9f36733
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 23 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552))
- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556))

- Added simclr, swav, barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Changed

Expand Down
39 changes: 18 additions & 21 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,14 @@ def _fn_resnet(model_name: str,
pretrained: Union[bool, str] = True,
weights_paths: dict = {"supervised": None}) -> Tuple[nn.Module, int]:
# load according to pretrained if a bool is specified, else set to False
if isinstance(pretrained, bool):
pretrained_flag = pretrained
else:
pretrained_flag = False
pretrained_flag = (pretrained and isinstance(pretrained, bool)) or (pretrained == "supervised")

model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained_flag)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features

model_weights = None
if isinstance(pretrained, str):
if not pretrained_flag and isinstance(pretrained, str):
if pretrained in weights_paths:
device = next(model.parameters()).get_device()
model_weights = load_state_dict_from_url(
Expand All @@ -118,9 +115,11 @@ def _fn_resnet(model_name: str,
for (key, val) in model_weights.items()
}
else:
raise KeyError('Model state dict loaded from unrecognized url/path.')
raise KeyError('Unrecognized state dict. Logic for loading the current state dict missing.')
else:
raise KeyError('Unrecognized pretrained model.')
raise KeyError(
"Requested weights for {0} not available,"
" choose from one of {1}".format(model_name, list(weights_paths.keys())))

return backbone, num_features

Expand All @@ -134,24 +133,22 @@ def _fn_resnet_fpn(
return backbone, 256

for model_name in RESNET_MODELS:
clf_kwargs = dict(
fn=catch_url_error(partial(_fn_resnet, model_name=model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="resnet",
weights_paths={"supervised": None}
)

if model_name == 'resnet50':
IMAGE_CLASSIFIER_BACKBONES(
clf_kwargs.update(dict(
fn=catch_url_error(partial(_fn_resnet, model_name=model_name, weights_paths=RESNET50_WEIGHTS_PATHS)),
name=model_name,
namespace="vision",
package="multiple",
type="resnet",
weights_paths=RESNET50_WEIGHTS_PATHS
)
else:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name=model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="resnet",
weights_paths={"supervised": None}
weights_paths=RESNET50_WEIGHTS_PATHS)
)
IMAGE_CLASSIFIER_BACKBONES(**clf_kwargs)

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
print(ImageClassifier.available_pretrained_weights('resnet50'))
model = ImageClassifier(backbone="resnet50", pretrained='abcd', num_classes=datamodule.num_classes)
exit(-1)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
Expand Down

0 comments on commit 9f36733

Please sign in to comment.