Skip to content

Commit

Permalink
fix: fixed the issues with the alexnet, unet and resnet models due to…
Browse files Browse the repository at this point in the history
… changes in the ivy.Module and transpose convolutions
  • Loading branch information
vedpatwardhan committed Jan 18, 2024
1 parent 3515ddc commit 3cdfe0c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ivy_models/alexnet/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def alexnet(pretrained=True, num_classes=1000, dropout=0, data_format="NCHW"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_alexnet_torch_weights_mapping
)
model.v = w_clean
model._v = w_clean
return model
10 changes: 5 additions & 5 deletions ivy_models/resnet/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def resnet_18(pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_resnet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model


Expand All @@ -197,7 +197,7 @@ def resnet_34(pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_resnet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model


Expand All @@ -212,7 +212,7 @@ def resnet_50(pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_resnet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model


Expand All @@ -227,7 +227,7 @@ def resnet_101(pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_resnet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model


Expand All @@ -242,5 +242,5 @@ def resnet_152(pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_resnet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model
9 changes: 5 additions & 4 deletions ivy_models/unet/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ivy_models.unet.layers import UNetDoubleConv, UNetDown, UNetOutConv, UNetUp

import builtins
import re
from ivy_models.helpers import load_torch_weights
from ivy_models.base import BaseSpec, BaseModel

Expand Down Expand Up @@ -66,9 +67,9 @@ def _unet_torch_weights_mapping(old_key, new_key):
"conv/weight",
]

if "up/weight" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w b c"}
elif builtins.any([kc in old_key for kc in W_KEY]):
if builtins.any([kc in old_key for kc in W_KEY]) or re.match(
"up\d/((conv/double_conv/0)|(up))/weight", old_key
):
new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
elif "conv/bias" in old_key or "up/bias" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "h -> 1 1 1 h"}
Expand All @@ -86,7 +87,7 @@ def unet_carvana(n_channels=3, n_classes=2, v=None, pretrained=True):
raw_keys_to_prune=["num_batches_tracked"],
custom_mapping=_unet_torch_weights_mapping,
)
model.v = w_clean
model._v = w_clean
return model


Expand Down

0 comments on commit 3cdfe0c

Please sign in to comment.