Skip to content

Commit

Permalink
[Dy2St][NO.14] pir dy2st unittest fix test_slice - Part 2 (#60200)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Dec 21, 2023
1 parent 5c4d98d commit e3b6061
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions test/dygraph_to_static/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import paddle
from paddle.framework import use_pir_api
from paddle.static import InputSpec

SEED = 2020
Expand Down Expand Up @@ -178,6 +179,14 @@ def init_input(self):
def init_dygraph_func(self):
self.dygraph_func = test_set_value

# TODO(pir-control-flow): Delete this code after supporting control flow
@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


class TestSetValueWithLayerAndSave(Dy2StTestBase):
def setUp(self):
Expand All @@ -190,18 +199,21 @@ def tearDown(self):
self.temp_dir.cleanup()

@test_ast_only
@test_legacy_and_pt_and_pir
def test_set_value_with_save(self):
with enable_to_static_guard(True):
model = paddle.jit.to_static(
LayerWithSetValue(input_dim=10, hidden=1)
)
x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32")
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)


class TestSliceSupplementSpecialCase(Dy2StTestBase):
Expand Down

0 comments on commit e3b6061

Please sign in to comment.