diff --git a/paddleformers/nn/lm_head.py b/paddleformers/nn/lm_head.py index 8b2f81cbd07..c88afa53914 100644 --- a/paddleformers/nn/lm_head.py +++ b/paddleformers/nn/lm_head.py @@ -93,7 +93,7 @@ def forward(self, hidden_states, tensor_parallel_output=None): hidden_states, self.weight, self.bias, - self.config.tie_word_embeddings, + True, ) return calc_lm_head_logits( diff --git a/tests/nn/test_lm_head.py b/tests/nn/test_lm_head.py index 781ac5fd6e8..eea4cac1399 100644 --- a/tests/nn/test_lm_head.py +++ b/tests/nn/test_lm_head.py @@ -60,7 +60,7 @@ def test_forward_fused_head_loss(self): self.assertEqual(output[0].shape, test_input.shape) self.assertEqual(output[1].shape, lm_head.weight.shape) self.assertIs(output[2], lm_head.bias) - self.assertEqual(output[3], config.tie_word_embeddings) + self.assertEqual(output[3], True) if __name__ == "__main__":