diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index a770902faa108..40fea46157f27 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -659,7 +659,7 @@ def binary_cross_entropy(
             % reduction
         )
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         out = _C_ops.bce_loss(input, label)
         if weight is not None:
             out = _C_ops.multiply(out, weight, 'axis', -1)
@@ -984,7 +984,7 @@ def hsigmoid_loss(
     if num_classes < 2:
         raise ValueError(f'Expected num_classes >= 2 (got {num_classes})')
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         out, _, _ = _C_ops.hsigmoid_loss(
             input,
             label,
@@ -1103,7 +1103,7 @@ def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
 
     """
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         out = _C_ops.huber_loss(input, label, delta)
     else:
         check_variable_and_dtype(
@@ -1329,7 +1329,7 @@ def l1_loss(input, label, reduction='mean', name=None):
             "received %s, which is not allowed." % reduction
         )
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         unreduced = _C_ops.abs(_C_ops.subtract(input, label))
 
         if reduction == 'mean':
@@ -1688,7 +1688,7 @@ def kl_div(input, label, reduction='mean', name=None):
     ):
         label = paddle.cast(label, 'float64')
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         out = _C_ops.kldiv_loss(input, label, 'none')
         if reduction == 'mean':
             out = paddle.mean(out)
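Every functional loss above gets the same one-line change: `in_dynamic_mode()` is true only for eager execution, while `in_dynamic_or_pir_mode()` is also true when the new PIR static graph is active, so the direct `_C_ops` path now serves both. A minimal sketch of the dispatch pattern, assuming `in_dynamic_or_pir_mode` is importable from `paddle.framework` (as `in_pir_mode` is in the tests below); the helper name is hypothetical:

# Sketch only: `abs_diff_loss` is illustrative, not part of this PR.
import paddle
from paddle import _C_ops
from paddle.framework import in_dynamic_or_pir_mode

def abs_diff_loss(input, label):
    if in_dynamic_or_pir_mode():
        # Eager mode and PIR graphs both lower directly to C++ ops.
        return _C_ops.abs(_C_ops.subtract(input, label))
    # The legacy static graph would build ops via LayerHelper instead.
    raise NotImplementedError('legacy static-graph branch omitted in sketch')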
diff --git a/test/legacy_test/test_bce_loss.py b/test/legacy_test/test_bce_loss.py
index a9fe9cfa030d9..007bdffad0288 100644
--- a/test/legacy_test/test_bce_loss.py
+++ b/test/legacy_test/test_bce_loss.py
@@ -20,6 +20,7 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 def test_static_layer(
@@ -152,6 +153,7 @@ def calc_bceloss(input_np, label_np, reduction='mean', weight_np=None):
 
 
 class TestBCELoss(unittest.TestCase):
+    @test_with_pir_api
     def test_BCELoss(self):
         input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
         label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
@@ -185,6 +187,7 @@ def test_BCELoss(self):
         )
         np.testing.assert_allclose(dy_functional, expected, rtol=1e-05)
 
+    @test_with_pir_api
     def test_BCELoss_weight(self):
         input_np = np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype(
             np.float64
         )
@@ -262,10 +265,10 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_pir=True)
 
     def init_test_case(self):
         self.shape = [10, 10]
@@ -286,17 +289,20 @@ def init_test_cast(self):
 
 
 class TestBceLossOpFP16(TestBceLossOp):
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_pir=True)
 
     def init_test_dtype(self):
         self.dtype = np.float16
 
 
 class TestBceLossOpStaticFP16(unittest.TestCase):
+    @test_with_pir_api
     def test_fp16(self):
+        if not core.is_compiled_with_cuda():
+            return
         paddle.enable_static()
         shape = [2, 3, 20]
         x_data = np.random.uniform(0.1, 0.8, shape).astype("float16")
diff --git a/test/legacy_test/test_hsigmoid_op.py b/test/legacy_test/test_hsigmoid_op.py
index 65cb8548e9eb8..9659b5e3b77d3 100644
--- a/test/legacy_test/test_hsigmoid_op.py
+++ b/test/legacy_test/test_hsigmoid_op.py
@@ -21,6 +21,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import base
+from paddle.pir_utils import test_with_pir_api
 
 paddle.enable_static()
 np.random.seed(100)
@@ -218,13 +219,14 @@ def setUp(self):
         self.user_grads = hsigmoid_grad(x, w, label, bias, num_classes)
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.check_grad(
             ['X', 'W', 'Bias'],
             ['Out'],
             user_defined_grads=self.user_grads,
+            check_pir=True,
         )
 
 
@@ -278,7 +280,7 @@ def setUp(self):
         self.outputs = {'PreOut': pre_output, 'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
 
 class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
@@ -323,9 +325,11 @@ def hs_net_conf(self, is_sparse):
         return avg_cost, data_list
 
     def training_test(self, is_sparse):
-        with base.program_guard(base.Program(), base.Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             paddle.seed(1)
-            start_up = base.default_startup_program()
+            start_up = paddle.static.default_startup_program()
             x = np.arange(6).reshape(6)
             path_table = np.array([(1, 2, -1), (1, 2, -1)]).astype('int64')
             path_code = np.array([(1, 0, -1), (0, 0, -1)]).astype('int64')
@@ -335,10 +339,10 @@ def training_test(self, is_sparse):
             optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
             optimizer.minimize(loss)
 
-            main_program = base.default_main_program()
+            main_program = paddle.static.default_main_program()
             place = base.CPUPlace()
             feeder = base.DataFeeder(feed_list=data_list, place=place)
-            exe = base.Executor(place)
+            exe = paddle.static.Executor(place)
 
             exe.run(start_up)
             result = []
@@ -414,13 +418,14 @@ def setUp(self):
         self.outputs = {'PreOut': pre_output, 'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.check_grad(
             ['Bias', 'X', 'W'],
             ['Out'],
             no_grad_set=set('Label'),
+            check_pir=True,
         )
 
 
@@ -479,10 +484,12 @@ def setUp(self):
         self.outputs = {'PreOut': pre_output, 'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label'))
+        self.check_grad(
+            ['X', 'W'], ['Out'], no_grad_set=set('Label'), check_pir=True
+        )
 
 
 class TestHSigmoidLossAPI(unittest.TestCase):
@@ -564,6 +571,7 @@ def test_dygraph_api(self):
             np.testing.assert_allclose(self.out_np, out.numpy(), rtol=1e-05)
         paddle.enable_static()
 
+    @test_with_pir_api
     def test_static_api(self):
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
@@ -619,10 +627,11 @@ def test_static_api(self):
         for ret in [ret1, ret2]:
             np.testing.assert_allclose(self.out_np, ret, rtol=1e-05)
 
+    @test_with_pir_api
     def test_base_api(self):
-        train_program = base.Program()
-        startup_program = base.Program()
-        with base.program_guard(train_program, startup_program):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(train_program, startup_program):
             x = paddle.static.data('x', [-1, self.feature_size])
             labels = paddle.static.data('labels', [-1, 1], 'int64')
             path_table = None
@@ -647,7 +656,7 @@ def test_base_api(self):
                 path_code=path_code,
             )
 
-            exe = base.Executor(self.place)
+            exe = paddle.static.Executor(self.place)
             exe.run(startup_program)
             feed_dict = {'x': self.x_np, 'labels': self.labels_np}
             if self.is_custom:
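The `test_with_pir_api` decorator applied throughout these tests runs the wrapped body twice, once under the legacy Program IR and once with PIR enabled. A simplified sketch of what `paddle/pir_utils.py` provides (the real decorator also manages translator flag state):

# Simplified sketch; see paddle/pir_utils.py for the actual implementation.
import functools
from paddle.pir_utils import IrGuard

def test_with_pir_api_sketch(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)  # first pass: legacy static graph
        with IrGuard():  # second pass: PIR enabled
            func(*args, **kwargs)
    return wrapper

Because each pass rebuilds the graph, decorated tests construct fresh `paddle.static.Program()` objects per call instead of reusing module-level state, which is why the static tests in this patch are also rewritten around `paddle.static.program_guard`.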
diff --git a/test/legacy_test/test_kldiv_loss_op.py b/test/legacy_test/test_kldiv_loss_op.py
index ea93d0e4dd607..599b9764c984d 100644
--- a/test/legacy_test/test_kldiv_loss_op.py
+++ b/test/legacy_test/test_kldiv_loss_op.py
@@ -18,6 +18,7 @@
 import paddle
 from paddle.nn.functional import kl_div
+from paddle.pir_utils import test_with_pir_api
 
 
 def kldiv_loss(x, target, reduction):
@@ -55,10 +56,10 @@ def setUp(self):
         self.outputs = {'Loss': loss.astype('float64')}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Loss', no_grad_set={"Target"})
+        self.check_grad(['X'], 'Loss', no_grad_set={"Target"}, check_pir=True)
 
     def initTestCase(self):
         self.x_shape = (4, 5, 5)
@@ -111,6 +112,7 @@ def test_kl_loss_sum(self):
     def test_kl_loss_none(self):
         self.run_kl_loss('none')
 
+    @test_with_pir_api
     def test_kl_loss_static_api(self):
         with paddle_static_guard():
             input = paddle.static.data(name='input', shape=[5, 20])
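The `kldiv_loss` reference helper compared against here (its body sits outside these hunks) follows Paddle's documented definition, where `x` holds log-probabilities and the elementwise term is `target * (log(target) - x)`. A numpy sketch consistent with that formula; the test file's own helper is authoritative:

import numpy as np

def kldiv_loss_ref(x, target, reduction='mean'):
    # Elementwise KL term; entries with non-positive targets contribute 0.
    loss = np.where(target > 0, target * (np.log(target) - x), 0.0)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'batchmean':
        return loss.sum() / x.shape[0]
    return loss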
diff --git a/test/legacy_test/test_l1_loss.py b/test/legacy_test/test_l1_loss.py
index 651d55977b34c..3a21e7ff97e48 100644
--- a/test/legacy_test/test_l1_loss.py
+++ b/test/legacy_test/test_l1_loss.py
@@ -18,6 +18,8 @@
 import paddle
 from paddle import base
+from paddle.framework import in_pir_mode
+from paddle.pir_utils import test_with_pir_api
 
 
 class TestFunctionalL1Loss(unittest.TestCase):
@@ -43,42 +45,48 @@ def run_imperative(self):
         np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
         self.assertEqual(dy_result.shape, [10, 10, 5])
 
+    @test_with_pir_api
     def run_static(self, use_gpu=False):
-        input = paddle.static.data(
-            name='input', shape=[10, 10, 5], dtype='float32'
-        )
-        label = paddle.static.data(
-            name='label', shape=[10, 10, 5], dtype='float32'
-        )
-        result0 = paddle.nn.functional.l1_loss(input, label)
-        result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum')
-        result2 = paddle.nn.functional.l1_loss(input, label, reduction='none')
-        y = paddle.nn.functional.l1_loss(input, label, name='aaa')
-
-        place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
-        exe = base.Executor(place)
-        exe.run(base.default_startup_program())
-        static_result = exe.run(
-            feed={"input": self.input_np, "label": self.label_np},
-            fetch_list=[result0, result1, result2],
-        )
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            input = paddle.static.data(
+                name='input', shape=[10, 10, 5], dtype='float32'
+            )
+            label = paddle.static.data(
+                name='label', shape=[10, 10, 5], dtype='float32'
+            )
+            result0 = paddle.nn.functional.l1_loss(input, label)
+            result1 = paddle.nn.functional.l1_loss(
+                input, label, reduction='sum'
+            )
+            result2 = paddle.nn.functional.l1_loss(
+                input, label, reduction='none'
+            )
+            y = paddle.nn.functional.l1_loss(input, label, name='aaa')
 
-        expected = np.mean(np.abs(self.input_np - self.label_np))
-        np.testing.assert_allclose(static_result[0], expected, rtol=1e-05)
-        expected = np.sum(np.abs(self.input_np - self.label_np))
-        np.testing.assert_allclose(static_result[1], expected, rtol=1e-05)
-        expected = np.abs(self.input_np - self.label_np)
-        np.testing.assert_allclose(static_result[2], expected, rtol=1e-05)
+            place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
+            exe = paddle.static.Executor(place)
+            static_result = exe.run(
+                feed={"input": self.input_np, "label": self.label_np},
+                fetch_list=[result0, result1, result2],
+            )
 
-        self.assertTrue('aaa' in y.name)
+            expected = np.mean(np.abs(self.input_np - self.label_np))
+            np.testing.assert_allclose(static_result[0], expected, rtol=1e-05)
+            expected = np.sum(np.abs(self.input_np - self.label_np))
+            np.testing.assert_allclose(static_result[1], expected, rtol=1e-05)
+            expected = np.abs(self.input_np - self.label_np)
+            np.testing.assert_allclose(static_result[2], expected, rtol=1e-05)
+            if not in_pir_mode():
+                self.assertTrue('aaa' in y.name)
 
     def test_cpu(self):
         paddle.disable_static(place=paddle.base.CPUPlace())
         self.run_imperative()
         paddle.enable_static()
 
-        with base.program_guard(base.Program()):
-            self.run_static()
+        self.run_static()
 
     def test_gpu(self):
         if not base.core.is_compiled_with_cuda():
@@ -88,11 +96,11 @@ def test_gpu(self):
         self.run_imperative()
         paddle.enable_static()
 
-        with base.program_guard(base.Program()):
-            self.run_static(use_gpu=True)
+        self.run_static(use_gpu=True)
 
     # test case the raise message
     def test_errors(self):
+        @test_with_pir_api
         def test_value_error():
             input = paddle.static.data(
                 name='input', shape=[10, 10, 5], dtype='float32'
             )
@@ -133,45 +141,49 @@ def run_imperative(self):
         np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
         self.assertEqual(dy_result.shape, [10, 10, 5])
 
+    @test_with_pir_api
     def run_static(self, use_gpu=False):
-        input = paddle.static.data(
-            name='input', shape=[10, 10, 5], dtype='float32'
-        )
-        label = paddle.static.data(
-            name='label', shape=[10, 10, 5], dtype='float32'
-        )
-        l1_loss = paddle.nn.loss.L1Loss()
-        result0 = l1_loss(input, label)
-        l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
-        result1 = l1_loss(input, label)
-        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
-        result2 = l1_loss(input, label)
-        l1_loss = paddle.nn.loss.L1Loss(name='aaa')
-        result3 = l1_loss(input, label)
-
-        place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
-        exe = base.Executor(place)
-        exe.run(base.default_startup_program())
-        static_result = exe.run(
-            feed={"input": self.input_np, "label": self.label_np},
-            fetch_list=[result0, result1, result2],
-        )
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            input = paddle.static.data(
+                name='input', shape=[10, 10, 5], dtype='float32'
+            )
+            label = paddle.static.data(
+                name='label', shape=[10, 10, 5], dtype='float32'
+            )
+            l1_loss = paddle.nn.loss.L1Loss()
+            result0 = l1_loss(input, label)
+            l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
+            result1 = l1_loss(input, label)
+            l1_loss = paddle.nn.loss.L1Loss(reduction='none')
+            result2 = l1_loss(input, label)
+            l1_loss = paddle.nn.loss.L1Loss(name='aaa')
+            result3 = l1_loss(input, label)
+
+            place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
+            exe = paddle.static.Executor(place)
+            static_result = exe.run(
+                feed={"input": self.input_np, "label": self.label_np},
+                fetch_list=[result0, result1, result2],
+            )
 
-        expected = np.mean(np.abs(self.input_np - self.label_np))
-        np.testing.assert_allclose(static_result[0], expected, rtol=1e-05)
-        expected = np.sum(np.abs(self.input_np - self.label_np))
-        np.testing.assert_allclose(static_result[1], expected, rtol=1e-05)
-        expected = np.abs(self.input_np - self.label_np)
-        np.testing.assert_allclose(static_result[2], expected, rtol=1e-05)
-        self.assertTrue('aaa' in result3.name)
+            expected = np.mean(np.abs(self.input_np - self.label_np))
+            np.testing.assert_allclose(static_result[0], expected, rtol=1e-05)
+            expected = np.sum(np.abs(self.input_np - self.label_np))
+            np.testing.assert_allclose(static_result[1], expected, rtol=1e-05)
+            expected = np.abs(self.input_np - self.label_np)
+            np.testing.assert_allclose(static_result[2], expected, rtol=1e-05)
+
+            if not in_pir_mode():
+                self.assertTrue('aaa' in result3.name)
 
     def test_cpu(self):
         paddle.disable_static(place=paddle.base.CPUPlace())
         self.run_imperative()
         paddle.enable_static()
 
-        with base.program_guard(base.Program()):
-            self.run_static()
+        self.run_static()
 
     def test_gpu(self):
         if not base.core.is_compiled_with_cuda():
@@ -181,11 +193,11 @@ def test_gpu(self):
         self.run_imperative()
         paddle.enable_static()
 
-        with base.program_guard(base.Program()):
-            self.run_static(use_gpu=True)
+        self.run_static(use_gpu=True)
 
     # test case the raise message
     def test_errors(self):
+        @test_with_pir_api
         def test_value_error():
             loss = paddle.nn.loss.L1Loss(reduction="reduce_mean")
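Note the new `if not in_pir_mode():` guard around the `'aaa' in ...name` assertions above: user-assigned names are recorded on legacy-IR `Variable`s, and PIR result values do not expose them the same way, so the name check is skipped on the PIR pass. A minimal illustration of the guarded check (the shape and the 'aaa' name simply mirror the test):

import paddle
from paddle.framework import in_pir_mode

paddle.enable_static()
x = paddle.static.data('x', [2, 3], 'float32')
y = paddle.nn.functional.l1_loss(x, x, name='aaa')
if not in_pir_mode():
    assert 'aaa' in y.name  # only meaningful under the legacy IR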
diff --git a/test/legacy_test/test_smooth_l1_loss.py b/test/legacy_test/test_smooth_l1_loss.py
index f070b747aeb5e..d9c1b3d4fcb13 100644
--- a/test/legacy_test/test_smooth_l1_loss.py
+++ b/test/legacy_test/test_smooth_l1_loss.py
@@ -18,6 +18,7 @@
 import paddle
 from paddle import base
+from paddle.pir_utils import test_with_pir_api
 
 
 def smooth_l1_loss_forward(val, delta):
@@ -46,33 +47,40 @@ def setUp(self):
     def test_smooth_l1_loss_mean(self):
         input_np = np.random.random([100, 200]).astype(np.float32)
         label_np = np.random.random([100, 200]).astype(np.float32)
-        prog = base.Program()
-        startup_prog = base.Program()
+
         place = (
             base.CUDAPlace(0)
             if base.core.is_compiled_with_cuda()
             else base.CPUPlace()
         )
-        with base.program_guard(prog, startup_prog):
-            input = paddle.static.data(
-                name='input', shape=[100, 200], dtype='float32'
-            )
-            label = paddle.static.data(
-                name='label', shape=[100, 200], dtype='float32'
-            )
-            smooth_l1_loss = paddle.nn.loss.SmoothL1Loss()
-            ret = smooth_l1_loss(input, label)
-
-        exe = base.Executor(place)
-        (static_ret,) = exe.run(
-            prog,
-            feed={
-                'input': input_np,
-                'label': label_np,
-            },
-            fetch_list=[ret],
-        )
-        self.assertIsNotNone(static_ret)
+
+        expected = smooth_l1_loss_np(input_np, label_np, reduction='mean')
+
+        @test_with_pir_api
+        def test_dynamic_or_pir_mode():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            with paddle.static.program_guard(prog, startup_prog):
+                input = paddle.static.data(
+                    name='input', shape=[100, 200], dtype='float32'
+                )
+                label = paddle.static.data(
+                    name='label', shape=[100, 200], dtype='float32'
+                )
+                smooth_l1_loss = paddle.nn.loss.SmoothL1Loss()
+                ret = smooth_l1_loss(input, label)
+
+                exe = paddle.static.Executor(place)
+                (static_ret,) = exe.run(
+                    feed={
+                        'input': input_np,
+                        'label': label_np,
+                    },
+                    fetch_list=[ret],
+                )
+                self.assertIsNotNone(static_ret)
+                np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
         with base.dygraph.guard():
             smooth_l1_loss = paddle.nn.loss.SmoothL1Loss()
             dy_ret = smooth_l1_loss(
@@ -81,41 +89,46 @@ def test_smooth_l1_loss_mean(self):
             )
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)
-        expected = smooth_l1_loss_np(input_np, label_np, reduction='mean')
-        np.testing.assert_allclose(static_ret, dy_ret_value, rtol=1e-05)
-        np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
+        test_dynamic_or_pir_mode()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)
 
     def test_smooth_l1_loss_sum(self):
         input_np = np.random.random([100, 200]).astype(np.float32)
         label_np = np.random.random([100, 200]).astype(np.float32)
-        prog = base.Program()
-        startup_prog = base.Program()
+
         place = (
             base.CUDAPlace(0)
             if base.core.is_compiled_with_cuda()
             else base.CPUPlace()
         )
-        with base.program_guard(prog, startup_prog):
-            input = paddle.static.data(
-                name='input', shape=[100, 200], dtype='float32'
-            )
-            label = paddle.static.data(
-                name='label', shape=[100, 200], dtype='float32'
-            )
-            smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='sum')
-            ret = smooth_l1_loss(input, label)
-
-        exe = base.Executor(place)
-        (static_ret,) = exe.run(
-            prog,
-            feed={
-                'input': input_np,
-                'label': label_np,
-            },
-            fetch_list=[ret],
-        )
-        self.assertIsNotNone(static_ret)
+        expected = smooth_l1_loss_np(input_np, label_np, reduction='sum')
+
+        @test_with_pir_api
+        def test_dynamic_or_pir_mode():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            with paddle.static.program_guard(prog, startup_prog):
+                input = paddle.static.data(
+                    name='input', shape=[100, 200], dtype='float32'
+                )
+                label = paddle.static.data(
+                    name='label', shape=[100, 200], dtype='float32'
+                )
+                smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='sum')
+                ret = smooth_l1_loss(input, label)
+
+                exe = paddle.static.Executor(place)
+                (static_ret,) = exe.run(
+                    feed={
+                        'input': input_np,
+                        'label': label_np,
+                    },
+                    fetch_list=[ret],
+                )
+                self.assertIsNotNone(static_ret)
+                np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
         with base.dygraph.guard():
             smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='sum')
             dy_ret = smooth_l1_loss(
@@ -124,41 +137,46 @@ def test_smooth_l1_loss_sum(self):
             )
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)
-        expected = smooth_l1_loss_np(input_np, label_np, reduction='sum')
-        np.testing.assert_allclose(static_ret, dy_ret_value, rtol=1e-05)
-        np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
+        test_dynamic_or_pir_mode()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)
 
     def test_smooth_l1_loss_none(self):
         input_np = np.random.random([100, 200]).astype(np.float32)
         label_np = np.random.random([100, 200]).astype(np.float32)
-        prog = base.Program()
-        startup_prog = base.Program()
+
         place = (
             base.CUDAPlace(0)
             if base.core.is_compiled_with_cuda()
             else base.CPUPlace()
         )
-        with base.program_guard(prog, startup_prog):
-            input = paddle.static.data(
-                name='input', shape=[100, 200], dtype='float32'
-            )
-            label = paddle.static.data(
-                name='label', shape=[100, 200], dtype='float32'
-            )
-            smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='none')
-            ret = smooth_l1_loss(input, label)
-
-        exe = base.Executor(place)
-        (static_ret,) = exe.run(
-            prog,
-            feed={
-                'input': input_np,
-                'label': label_np,
-            },
-            fetch_list=[ret],
-        )
-        self.assertIsNotNone(static_ret)
+        expected = smooth_l1_loss_np(input_np, label_np, reduction='none')
+
+        @test_with_pir_api
+        def test_dynamic_or_pir_mode():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            with paddle.static.program_guard(prog, startup_prog):
+                input = paddle.static.data(
+                    name='input', shape=[100, 200], dtype='float32'
+                )
+                label = paddle.static.data(
+                    name='label', shape=[100, 200], dtype='float32'
+                )
+                smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='none')
+                ret = smooth_l1_loss(input, label)
+
+                exe = paddle.static.Executor(place)
+                (static_ret,) = exe.run(
+                    feed={
+                        'input': input_np,
+                        'label': label_np,
+                    },
+                    fetch_list=[ret],
+                )
+                self.assertIsNotNone(static_ret)
+                np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
         with base.dygraph.guard():
             smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(reduction='none')
             dy_ret = smooth_l1_loss(
@@ -167,42 +185,47 @@ def test_smooth_l1_loss_none(self):
             )
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)
-        expected = smooth_l1_loss_np(input_np, label_np, reduction='none')
-        np.testing.assert_allclose(static_ret, dy_ret_value, rtol=1e-05)
-        np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
+        test_dynamic_or_pir_mode()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)
 
     def test_smooth_l1_loss_delta(self):
         input_np = np.random.random([100, 200]).astype(np.float32)
         label_np = np.random.random([100, 200]).astype(np.float32)
         delta = np.random.rand()
-        prog = base.Program()
-        startup_prog = base.Program()
+
         place = (
             base.CUDAPlace(0)
             if base.core.is_compiled_with_cuda()
             else base.CPUPlace()
        )
-        with base.program_guard(prog, startup_prog):
-            input = paddle.static.data(
-                name='input', shape=[100, 200], dtype='float32'
-            )
-            label = paddle.static.data(
-                name='label', shape=[100, 200], dtype='float32'
-            )
-            smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(delta=delta)
-            ret = smooth_l1_loss(input, label)
-
-        exe = base.Executor(place)
-        (static_ret,) = exe.run(
-            prog,
-            feed={
-                'input': input_np,
-                'label': label_np,
-            },
-            fetch_list=[ret],
-        )
-        self.assertIsNotNone(static_ret)
+        expected = smooth_l1_loss_np(input_np, label_np, delta=delta)
+
+        @test_with_pir_api
+        def test_dynamic_or_pir_mode():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            with paddle.static.program_guard(prog, startup_prog):
+                input = paddle.static.data(
+                    name='input', shape=[100, 200], dtype='float32'
+                )
+                label = paddle.static.data(
+                    name='label', shape=[100, 200], dtype='float32'
+                )
+                smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(delta=delta)
+                ret = smooth_l1_loss(input, label)
+
+                exe = paddle.static.Executor(place)
+                (static_ret,) = exe.run(
+                    feed={
+                        'input': input_np,
+                        'label': label_np,
+                    },
+                    fetch_list=[ret],
+                )
+                self.assertIsNotNone(static_ret)
+                np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
         with base.dygraph.guard():
             smooth_l1_loss = paddle.nn.loss.SmoothL1Loss(delta=delta)
             dy_ret = smooth_l1_loss(
@@ -211,9 +234,8 @@ def test_smooth_l1_loss_delta(self):
             )
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)
-        expected = smooth_l1_loss_np(input_np, label_np, delta=delta)
-        np.testing.assert_allclose(static_ret, dy_ret_value, rtol=1e-05)
-        np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
+
+        test_dynamic_or_pir_mode()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)
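For reference, the `smooth_l1_loss_forward`/`smooth_l1_loss_np` helpers these tests compare against (defined outside the hunks shown) implement the Huber form that `_C_ops.huber_loss` computes: quadratic within `delta` of zero, linear beyond it. A numpy sketch under that assumption:

import numpy as np

def smooth_l1_ref(x, y, delta=1.0, reduction='mean'):
    # Huber loss: 0.5*z**2 where |z| < delta, else delta*|z| - 0.5*delta**2.
    z = np.abs(x - y)
    loss = np.where(z < delta, 0.5 * z * z, delta * z - 0.5 * delta * delta)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss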