diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 50f7d6d862f7..4c74a04d4f34 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -96,6 +96,7 @@ class Conv1D(nn.Module): def __init__(self, nf, nx): super().__init__() self.nf = nf + self.nx = nx self.weight = nn.Parameter(torch.empty(nx, nf)) self.bias = nn.Parameter(torch.zeros(nf)) nn.init.normal_(self.weight, std=0.02)