
A question about load_state_dict #7

Open
shenyehui opened this issue Mar 28, 2024 · 5 comments

Comments

@shenyehui

pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict)

I used the above code to load the pretrained lightvit_tiny model, but it doesn't work. How can I load the pretrained lightvit_tiny model correctly?

@hunto
Owner

hunto commented Mar 28, 2024

You may use

lightvit.load_state_dict(pretrained_state_dict["state_dict"])
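In other words, the checkpoint file is a dictionary that nests the actual weights under a "state_dict" key (a common layout for checkpoints that also store training metadata). A minimal, hypothetical helper that unwraps this layout (and also strips a "module." prefix left by DataParallel training, an assumption, not something confirmed for this checkpoint) might look like:

```python
import torch


def extract_state_dict(ckpt):
    """Unwrap a checkpoint dict whose weights sit under a "state_dict" key,
    and strip an optional "module." prefix from DataParallel-trained keys."""
    sd = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    return {
        k[len("module."):] if k.startswith("module.") else k: v
        for k, v in sd.items()
    }


# Illustration with an in-memory checkpoint (no file needed):
ckpt = {"state_dict": {"module.head.weight": torch.zeros(10, 8)}, "epoch": 300}
sd = extract_state_dict(ckpt)
print(sorted(sd))  # ['head.weight']
```

With a real file you would pass `torch.load(pretrained_weights_path)` through the same helper before calling `load_state_dict`.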

@hunto
Owner

hunto commented Mar 28, 2024

Can you provide more detailed error logs?

@shenyehui
Author

shenyehui commented Mar 28, 2024

Can you provide more detailed error logs?

Thank you! I successfully loaded the pre-trained model, but this error occurs when I use the following code to inspect the shape of the tensor output before the pooling layer:
pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict["state_dict"])
featureslightViT = list(lightvit.children())[:-1]
self.backbone = nn.Sequential(*featureslightViT)
TypeError: forward() missing 2 required positional arguments: 'H' and 'W'
My input is inputs = torch.rand((1, 3, 224, 224)).

@hunto
Owner

hunto commented Mar 28, 2024

You cannot simply wrap the children in nn.Sequential, since some blocks require H and W as inputs. Please refer to forward_features:
https://github.com/hunto/image_classification_sota/blob/36539b63cc8b851bd3fc93251bba60528813bb36/lib/models/lightvit.py#L384
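To make the failure mode concrete, here is a minimal stand-in model (not the real LightViT, just a sketch with the same calling convention) that reproduces the TypeError and shows a forward hook as one way to grab the pre-pooling tensor without rewiring the model:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Stand-in block whose forward needs extra positional arguments H and W,
    mimicking the LightViT blocks."""

    def forward(self, x, H, W):
        return x  # a real block would use H, W for spatial computation


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()
        self.head = nn.Linear(8, 10)

    def forward_features(self, x):
        # The model threads H and W through each block explicitly,
        # which nn.Sequential cannot do.
        return self.block(x, 4, 4)

    def forward(self, x):
        x = self.forward_features(x)
        return self.head(x.mean(dim=1))


model = TinyModel()

# Reproducing the reported error: Sequential calls Block.forward(x) alone.
seq = nn.Sequential(*model.children())
try:
    seq(torch.rand(1, 16, 8))
except TypeError as e:
    print("nn.Sequential fails:", e)

# Instead, capture the pre-pooling tensor with a forward hook:
feats = {}
model.block.register_forward_hook(lambda m, i, o: feats.update(pre_pool=o))
_ = model(torch.rand(1, 16, 8))
print(feats["pre_pool"].shape)  # torch.Size([1, 16, 8])
```

Calling `model.forward_features(x)` directly (with any H/W bookkeeping the real model does internally) is the other straightforward option.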

@shenyehui
Author

You cannot simply wrap the children in nn.Sequential, since some blocks require H and W as inputs. Please refer to forward_features: https://github.com/hunto/image_classification_sota/blob/36539b63cc8b851bd3fc93251bba60528813bb36/lib/models/lightvit.py#L384

I see in the diagram of the paper that the average pooling layer is in the HEAD section. Does this code mean that the pooling operation is not necessarily performed?
self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity()
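For what it's worth, the quoted line only selects the classifier layer: with num_classes <= 0 the head becomes nn.Identity, which returns its input unchanged, so whether pooling happens is decided elsewhere in the model's forward, not by this line. A standalone sketch of just that construction (not the real model):

```python
import torch
import torch.nn as nn

# Recreating the quoted head construction with illustrative sizes:
neck_dim, num_classes = 8, 0
head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity()

x = torch.rand(2, neck_dim)
print(head(x).shape)  # torch.Size([2, 8]): Identity passes features through
```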
