From 7d9ea90fa2adc007d92753a6d1f83845842df37e Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Thu, 7 Apr 2022 02:03:27 -0400
Subject: [PATCH] Using QATMatMul in DistilBERT model class

---
 .../models/distilbert/modeling_distilbert.py | 26 ++++++++++++++++---
 src/transformers/trainer.py                  |  6 +++++
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 883a89502b62..248dbfcbbbd7 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -91,6 +91,22 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class Embeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -153,6 +169,11 @@ def __init__(self, config):
 
         self.pruned_heads = set()
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
@@ -202,7 +223,7 @@ def unshape(x):
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
 
@@ -213,7 +234,7 @@ def unshape(x):
         if head_mask is not None:
             weights = weights * head_mask
 
-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
 
@@ -625,7 +646,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         dlbrt_output = self.distilbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f6449d71cde8..6129fa6af9f9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2436,6 +2436,12 @@ def evaluation_loop(
         observed_num_examples = 0
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
+            inputs = {
+                k: inputs[k]
+                for k in inputs
+                if k in list(inspect.signature(model.forward).parameters.keys())
+            }
+
             # Update the observed num examples
             observed_batch_size = find_batch_size(inputs)
             if observed_batch_size is not None: