diff --git a/tester/accuracy.py b/tester/accuracy.py index 6a24f96e..ea9b2038 100644 --- a/tester/accuracy.py +++ b/tester/accuracy.py @@ -377,6 +377,9 @@ def test(self): if self.api_config.api_name == "paddle.nn.utils.parameters_to_vector": paddle_out_grads = [] torch_out_grads = [] + if self.api_config.api_name == "paddle.nn.functional.kl_div": + paddle_out_grads = paddle_out_grads[:1] + torch_out_grads = torch_out_grads[:1] if self.api_config.api_name == "paddle.scale": paddle_out_grads = paddle_out_grads[0] torch_out_grads = torch_out_grads[0]