Skip to content

Commit 0ce5d45

Browse files
committed
fix fused_head_and_loss_fn bug
1 parent 6a34ed4 commit 0ce5d45

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/nn/test_lm_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_forward_fused_head_loss(self):
6060
self.assertEqual(output[0].shape, test_input.shape)
6161
self.assertEqual(output[1].shape, lm_head.weight.shape)
6262
self.assertIs(output[2], lm_head.bias)
63-
self.assertEqual(output[3], config.tie_word_embeddings)
63+
self.assertEqual(output[3], True)
6464

6565

6666
if __name__ == "__main__":

0 commit comments

Comments
 (0)