From abdcc7f11e4a56001a17a05c3005a00c8f5fc068 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 9 Mar 2022 21:34:26 +0100 Subject: [PATCH] fix --- .../models/deberta_v2/modeling_tf_deberta_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py index 445cb76256bb..f90dcd765e7c 100644 --- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py @@ -313,7 +313,7 @@ def call( rmask = tf.cast(1 - input_mask, tf.bool) out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out) out = self.dropout(out, training=training) - hidden_states = self.conv_act(out) + out = self.conv_act(out) layer_norm_input = residual_states + out output = self.LayerNorm(layer_norm_input) @@ -323,10 +323,10 @@ def call( else: if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)): if len(shape_list(input_mask)) == 4: - mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1) - mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32) + input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1) + input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32) - output_states = output * mask + output_states = output * input_mask return output_states