Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fix load from pretrained and add resnet 101
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Apr 28, 2021
1 parent 409a38a commit 7f8ebab
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
14 changes: 11 additions & 3 deletions flash/vision/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = False,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
Expand Down Expand Up @@ -104,6 +104,14 @@ def forward(self, x) -> torch.Tensor:


@SemanticSegmentation.backbones(name="torchvision/fcn_resnet50")
def fn(pretrained: bool, num_classes: int) -> nn.Module:
model: nn.Module = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained, num_classes=num_classes)
def load_torchvision_fcn_resnet50(pretrained: bool, num_classes: int) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model


@SemanticSegmentation.backbones(name="torchvision/fcn_resnet101")
def load_torchvision_fcn_resnet101(pretrained: bool, num_classes: int) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model
9 changes: 5 additions & 4 deletions flash_examples/finetuning/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]:
batch_size=4,
val_split=0.3, # TODO: this needs to be implemented
image_size=(300, 400), # (600, 800)
num_workers=0,
)

# 2.2 Visualise the samples
labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
datamodule.set_labels_map(labels_map)
datamodule.show_train_batch("load_sample")
datamodule.show_train_batch("post_tensor_transform")
#datamodule.show_train_batch("load_sample")
#datamodule.show_train_batch("post_tensor_transform")

# 3. Build the model
model = SemanticSegmentation(
Expand All @@ -69,9 +70,9 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]:

# 4. Create the trainer.
trainer = flash.Trainer(
max_epochs=5,
max_epochs=1,
gpus=1,
# precision=16, # why slower ? :)
#precision=16, # why slower ? :)
)

# 5. Train the model
Expand Down
1 change: 1 addition & 0 deletions tests/vision/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_forward(num_classes, img_shape):
"backbone",
[
"torchvision/fcn_resnet50",
"torchvision/fcn_resnet101",
],
)
def test_init_train(tmpdir, backbone):
Expand Down

0 comments on commit 7f8ebab

Please sign in to comment.