questions for loading the pretrained_model #21

Open
mingbocui opened this issue Nov 12, 2019 · 2 comments
@mingbocui

    def load(self, model_file, pretrain_file):
        """ load saved model or pretrained transformer (a part of model) """
        if model_file:
            print('Loading the model from', model_file)
            self.model.load_state_dict(torch.load(model_file))

        elif pretrain_file: # use pretrained transformer
            print('Loading the pretrained model from', pretrain_file)
            if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
                checkpoint.load_model(self.model.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
                self.model.transformer.load_state_dict(
                    {key[12:]: value
                        for key, value in torch.load(pretrain_file).items()
                        if key.startswith('transformer')}
                ) # load only transformer parts

Could I kindly ask what `key[12:]: value` means when you load a pretrained model? Is it meant to keep only the last layer? Thanks, I hope for your reply.

@dhlee347
Owner

It is because I wanted to load only the transformer part of the saved model, not the whole model.
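To make the slice concrete, here is a minimal sketch of what the dict comprehension in the question appears to do. It assumes the saved keys start with the prefix `transformer.` (11 letters plus a dot, so 12 characters), which is what the `startswith('transformer')` filter suggests. A plain dict stands in for the mapping returned by `torch.load`, and the key names and the `strip_prefix` helper are hypothetical, chosen only for illustration.

```python
# Plain dict standing in for the state dict returned by torch.load(pretrain_file).
# The key names are hypothetical; only the 'transformer.' prefix matters here.
saved_state = {
    "transformer.embed.tok_embed.weight": "tensor A",
    "transformer.blocks.0.attn.proj_q.weight": "tensor B",
    "classifier.weight": "tensor C",  # not part of the transformer
}

PREFIX = "transformer."  # 11 letters plus the dot: 12 characters


def strip_prefix(state, prefix=PREFIX):
    """Keep only entries under `prefix` and drop the prefix from their keys."""
    return {
        key[len(prefix):]: value  # same effect as key[12:] for this prefix
        for key, value in state.items()
        if key.startswith(prefix)
    }


transformer_state = strip_prefix(saved_state)
print(len(PREFIX))
print(sorted(transformer_state))
```

Writing `key[len(prefix):]` instead of `key[12:]` ties the slice to the prefix it removes, so the result has the key names that the submodule `self.model.transformer` would expect in `load_state_dict`.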

@mingbocui
Author

@dhlee347 Thanks for your reply. I have one more question: if I change the number of BERT layers from 12 to 6, should I change `key[12:]` to `key[6:]`?
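The thread does not record an answer to this, so the following is only a sketch under stated assumptions. If the 12 in `key[12:]` is the length of the string `transformer.`, as the `startswith('transformer')` filter suggests, then it is unrelated to the number of layers, and `key[6:]` would produce malformed names rather than select six layers. Choosing layers would instead mean filtering on the layer index inside each key. The `blocks.<i>.` key layout, the constants, and the `select_layers` helper below are hypothetical, and a plain dict again stands in for a real state dict.

```python
import re

PREFIX = "transformer."   # assumed key prefix, 12 characters long
NUM_LAYERS_SAVED = 12     # layers in the hypothetical pretrained file
NUM_LAYERS_WANTED = 6     # layers in the smaller model

# Hypothetical saved keys: one embedding entry plus one entry per layer.
saved_state = {"transformer.embed.weight": "emb"}
for i in range(NUM_LAYERS_SAVED):
    saved_state[f"transformer.blocks.{i}.weight"] = f"layer {i}"

# Matches the layer index at the start of a prefix-stripped key.
BLOCK_KEY = re.compile(r"^blocks\.(\d+)\.")


def select_layers(state, n_layers, prefix=PREFIX):
    """Strip `prefix` and keep non-layer entries plus layers 0..n_layers-1."""
    selected = {}
    for key, value in state.items():
        if not key.startswith(prefix):
            continue
        short = key[len(prefix):]  # the slice length follows the prefix
        match = BLOCK_KEY.match(short)
        if match and int(match.group(1)) >= n_layers:
            continue  # layer index is beyond the smaller model
        selected[short] = value
    return selected


small_state = select_layers(saved_state, NUM_LAYERS_WANTED)
print(sorted(small_state))

# Slicing at 6 only cuts the prefix in the middle:
print("transformer.embed.weight"[6:])
```

Whether the first six pretrained layers are a good starting point for a six-layer model is a separate modelling question that this sketch does not address.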

mingbocui reopened this Nov 29, 2019