diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index acf9607a7367..0558a45a9535 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -152,6 +152,23 @@ def forward(self, input_tensor):
 NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
 
+class QATEmbeddingTransformation(nn.Module):
+    def __init__(self, embedded_input_size, hidden_size):
+        super().__init__()
+
+        # Behaves like a normal Linear module unless a SparseML QuantizationModifier
+        # is initialized.
+        # When initialized, inputs are not quantized; only the weights are
+        # quantized (inputs arrive already quantized from the embeddings).
+        self.linear = nn.Linear(embedded_input_size, hidden_size)
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
 
 class MobileBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
 
@@ -168,7 +185,7 @@ def __init__(self, config):
 
         embed_dim_multiplier = 3 if self.trigram_input else 1
         embedded_input_size = self.embedding_size * embed_dim_multiplier
-        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
+        self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)
 
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
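
For context, a minimal sketch (not part of the patch) of how the new module behaves on its own, assuming the patched `transformers` is installed and no SparseML `QuantizationModifier` has been applied: `QATEmbeddingTransformation` is a thin wrapper around `nn.Linear`, and `wrap_qat` / `qat_wrapper_kwargs` are plain attributes left for SparseML's quantization flow to pick up. The tensor sizes below are illustrative only.

```python
import torch

# Assumes the patched transformers from this diff is installed.
from transformers.models.mobilebert.modeling_mobilebert import QATEmbeddingTransformation

# Illustrative sizes mirroring MobileBERT's trigram input: 3 * embedding_size -> hidden_size.
embedded_input_size, hidden_size = 3 * 128, 512

module = QATEmbeddingTransformation(embedded_input_size, hidden_size)

# With no QuantizationModifier applied, forward() is exactly the inner nn.Linear.
x = torch.randn(2, 16, embedded_input_size)
assert torch.equal(module(x), module.linear(x))

# QAT hints that a SparseML quantization flow can read when wrapping the module.
print(module.wrap_qat, module.qat_wrapper_kwargs)  # True {'num_inputs': 0, 'num_outputs': 1}
```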