Skip to content

Commit

Permalink
Fix VGG model
Browse files Browse the repository at this point in the history
We should be using cont_set_at_key_chain function to update the weights instead of directly assigning the .v attribute which doesn't update the weights of individual layers
  • Loading branch information
hello-fri-end authored Oct 11, 2023
1 parent b8216d8 commit 3edcdff
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ivy_models/vgg/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def vgg11(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -123,7 +123,7 @@ def vgg11_bn(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -135,7 +135,7 @@ def vgg13(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -147,7 +147,7 @@ def vgg13_bn(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -159,7 +159,7 @@ def vgg16(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -171,7 +171,7 @@ def vgg16_bn(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -183,7 +183,7 @@ def vgg19(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand All @@ -195,7 +195,7 @@ def vgg19_bn(pretrained=True, data_format="NHWC"):
w_clean = ivy_models.helpers.load_torch_weights(
url, model, custom_mapping=_vgg_torch_weights_mapping
)
model.v = w_clean
model.v.cont_set_at_key_chains(w_clean, inplace=True)
return model


Expand Down

0 comments on commit 3edcdff

Please sign in to comment.