diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py
index afce29d6bd72..a76e681c3da3 100644
--- a/src/transformers/models/cvt/modeling_cvt.py
+++ b/src/transformers/models/cvt/modeling_cvt.py
@@ -451,7 +451,11 @@ def __init__(self, config, stage):
         self.config = config
         self.stage = stage
         if self.config.cls_token[self.stage]:
-            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.embed_dim[-1]))
+            self.cls_token = nn.Parameter(
+                nn.init.trunc_normal_(
+                    torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=config.initializer_range
+                )
+            )
 
         self.embedding = CvtEmbeddings(
             patch_size=config.patch_sizes[self.stage],
@@ -547,9 +551,7 @@ class CvtPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
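
For reviewers: a minimal, standalone sketch (not part of the patch) of the `nn.init.trunc_normal_` behavior the change relies on. The embedding width (384) and `initializer_range` (0.02, the `CvtConfig` default) are illustrative values, not taken from the diff.

```python
# Standalone sketch of the initializer behavior this patch relies on; the width
# (384) and std (0.02) below are illustrative values, not taken from the diff.
import torch
from torch import nn

initializer_range = 0.02  # assumed here; matches the CvtConfig default

# trunc_normal_ fills the tensor in place with samples from N(mean, std**2),
# redrawing anything outside the absolute bounds [a, b] (default [-2.0, 2.0]),
# and returns the same tensor -- so it can be wrapped directly in nn.Parameter.
cls_token = nn.Parameter(
    nn.init.trunc_normal_(torch.zeros(1, 1, 384), mean=0.0, std=initializer_range)
)

# Unlike plain .normal_(), the tails are hard-bounded: no sample can exceed b.
assert cls_token.data.abs().max().item() <= 2.0
print(cls_token.shape, cls_token.data.std().item())
```

Since `trunc_normal_` works in place and returns its argument, the `module.weight.data = ...` assignment in `_init_weights` is technically redundant, but it keeps the replacement a one-line drop-in for the old `normal_` call.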