Skip to content
This repository has been archived by the owner on Nov 11, 2024. It is now read-only.

Commit

Permalink
Adds PAN model architecture and four new encoder types
Browse files Browse the repository at this point in the history
  • Loading branch information
OllyK committed Nov 25, 2022
1 parent 530fb3b commit f4a7746
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 2 deletions.
23 changes: 23 additions & 0 deletions tests/test_model_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
utils.ModelType.DEEPLABV3_PLUS,
utils.ModelType.MA_NET,
utils.ModelType.LINKNET,
utils.ModelType.PAN,
],
)
def test_create_model_on_device(binary_model_struc_dict, model_type):
Expand All @@ -29,6 +30,28 @@ def test_create_model_on_device(binary_model_struc_dict, model_type):
assert device.index == 0


@pytest.mark.gpu
@pytest.mark.parametrize(
"encoder_type",
[
"resnet34",
"resnet50",
"resnext50_32x4d",
"efficientnet-b3",
"efficientnet-b4",
"timm-resnest50d",
"timm-resnest101e",
],
)
def test_create_model_on_device_encoders(binary_model_struc_dict, encoder_type):
binary_model_struc_dict["encoder_name"] = encoder_type
model = create_model_on_device(0, binary_model_struc_dict)
assert isinstance(model, torch.nn.Module)
device = next(model.parameters()).device
assert device.type == "cuda"
assert device.index == 0


@pytest.mark.gpu
def test_create_model_from_file(model_path):
model, classes, codes = create_model_from_file(model_path)
Expand Down
5 changes: 3 additions & 2 deletions volseg-settings/2d_model_train_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ plot_lr_graph: False # Set to True if gnuplot is installed and a terminal plot o
model:
# Choose type of segmentation model from the list of those tested so far
# ["U_Net", "U_Net_Plus_plus", "FPN", "DeepLabV3", "DeepLabV3_Plus"]
# "MA_Net", "Linknet"]
# "MA_Net", "Linknet", "PAN"]
type: "U_Net"
# For more details on encoder types please see smp.readthedocs.io
# choose encoder, those tested so far include the following:
# ["resnet34", "resnet50", "resnext50_32x4d"]
# ["resnet34", "resnet50", "resnext50_32x4d", "efficientnet-b3","efficientnet-b4",
# "timm-resnest50d"*, "timm-resnest101e"*,] *Encoders with asterisk not compatible with PAN.
encoder_name: "resnet34"
# use `imagenet` pre-trained weights for encoder initialization
encoder_weights: "imagenet"
Expand Down
2 changes: 2 additions & 0 deletions volume_segmantics/model/model_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def create_model_on_device(device_num: int, model_struc_dict: dict) -> torch.nn.
logging.info(f"Sending the MA-Net model to device {device_num}")
elif model_type == utils.ModelType.LINKNET:
model = smp.Linknet(**struct_dict_copy)
elif model_type == utils.ModelType.PAN:
model = smp.PAN(**struct_dict_copy)
logging.info(f"Sending the Linknet model to device {device_num}")
return model.to(device_num)

Expand Down
1 change: 1 addition & 0 deletions volume_segmantics/utilities/base_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ModelType(Enum):
DEEPLABV3_PLUS = 5
MA_NET = 6
LINKNET = 7
PAN = 8


def create_enum_from_setting(setting_str, enum):
Expand Down

0 comments on commit f4a7746

Please sign in to comment.