diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index 7a369a1ea265c..f26cdd5630e05 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -26,6 +26,7 @@ ) import paddle +from paddle.framework import use_pir_api from paddle.static import InputSpec SEED = 2020 @@ -178,6 +179,14 @@ def init_input(self): def init_dygraph_func(self): self.dygraph_func = test_set_value + # TODO(pir-control-flow): Delete this code after supporting control flow + @test_legacy_and_pt_and_pir + def test_transformed_static_result(self): + self.init_dygraph_func() + static_res = self.run_static_mode() + dygraph_res = self.run_dygraph_mode() + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + class TestSetValueWithLayerAndSave(Dy2StTestBase): def setUp(self): @@ -190,18 +199,21 @@ def tearDown(self): self.temp_dir.cleanup() @test_ast_only + @test_legacy_and_pt_and_pir def test_set_value_with_save(self): with enable_to_static_guard(True): model = paddle.jit.to_static( LayerWithSetValue(input_dim=10, hidden=1) ) x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32") - paddle.jit.save( - layer=model, - path=self.model_path, - input_spec=[x], - output_spec=None, - ) + # TODO(pir-save-load): Fix this after we support save/load in PIR + if not use_pir_api(): + paddle.jit.save( + layer=model, + path=self.model_path, + input_spec=[x], + output_spec=None, + ) class TestSliceSupplementSpecialCase(Dy2StTestBase):