def load(self, model_file, pretrain_file):
    """ load saved model or pretrained transformer (a part of model) """
    if model_file:
        print('Loading the model from', model_file)
        self.model.load_state_dict(torch.load(model_file))
    elif pretrain_file:  # use pretrained transformer
        print('Loading the pretrained model from', pretrain_file)
        if pretrain_file.endswith('.ckpt'):  # checkpoint file in tensorflow
            checkpoint.load_model(self.model.transformer, pretrain_file)
        elif pretrain_file.endswith('.pt'):  # pretrained model file in pytorch
            self.model.transformer.load_state_dict(
                {key[12:]: value
                 for key, value in torch.load(pretrain_file).items()
                 if key.startswith('transformer')}
            )  # load only transformer parts
Could I kindly ask what the meaning of key[12:]: value is when you load a pretrained model? Is it just to keep the last layer? Thanks, I hope for your reply.
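For context, here is a minimal, self-contained sketch of what that dictionary comprehension does (the parameter names below are illustrative, not the exact ones in this repo). The saved state dict prefixes every transformer parameter with the string 'transformer.', which is exactly 12 characters long, so key[12:] strips that prefix to produce the names self.model.transformer expects. It keeps all transformer layers and only drops the non-transformer parts of the checkpoint, such as the task head.

# Minimal sketch: key[12:] strips the 'transformer.' prefix from state-dict keys.
prefix = 'transformer.'
assert len(prefix) == 12  # this is where the 12 in key[12:] comes from

# Illustrative saved state dict (keys are made up for demonstration).
saved_state = {
    'transformer.blocks.0.attn.proj_q.weight': 'tensor_a',
    'transformer.embed.tok_embed.weight': 'tensor_b',
    'classifier.weight': 'tensor_c',  # not a transformer key; filtered out
}

# Same comprehension as in load(): keep only transformer keys, drop the prefix.
transformer_state = {
    key[12:]: value
    for key, value in saved_state.items()
    if key.startswith('transformer')
}

print(sorted(transformer_state))
# ['blocks.0.attn.proj_q.weight', 'embed.tok_embed.weight']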
@dhlee347 thanks for your reply. I have one more question: if I change the number of BERT layers from 12 to 6, should I change key[12:] to key[6:]?
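A note on this, in case it helps: the 12 in key[12:] is the length of the prefix string 'transformer.', not the number of layers, so it would stay key[12:] even for a 6-layer model. A quick check:

# The slice index is tied to the prefix length, not the layer count.
print(len('transformer.'))  # 12 -> key[12:] removes exactly this prefix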