Skip to content

Commit

Permalink
[Dy2St][PIR] Add restore_out in PIR sot_call (#63190)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Apr 2, 2024
1 parent 5dfe454 commit a1f5cdb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
3 changes: 2 additions & 1 deletion python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ def sot_call(self, inputs):
self._cuda_graph_vec,
*attrs,
)
return out_vars
restored_nest_out = self._restore_out(out_vars)
return restored_nest_out

@cached_property
def origin_runnable_program(self):
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,16 @@ def infer_meta_for_layer(layer, *args, **kwargs):
partial_program_layer,
) = layer.forward.get_concrete_program(*args_, **kwargs_)

if use_pir_api():
output_values = partial_program_layer._outputs.var_list
else:
output_values = concrete_program.outputs

out = partial_program_layer._restore_out(
[
x
for x in paddle.utils.flatten(
convert_variable_to_meta_info(concrete_program.outputs)
convert_variable_to_meta_info(output_values)
)
if isinstance(x, MetaInfo)
]
Expand Down
30 changes: 25 additions & 5 deletions test/dygraph_to_static/test_duplicate_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@

np.random.seed(1)

if paddle.base.is_compiled_with_cuda():
place = paddle.base.CUDAPlace(0)
else:
place = paddle.base.CPUPlace()


class SimpleNet(paddle.nn.Layer):
def __init__(self):
Expand All @@ -41,6 +36,17 @@ def forward(self, x):
return x, x


class DuplicateOutputInPaddleLayer(paddle.nn.Layer):
def __init__(self):
super().__init__()
# In GRUCell, the output is a tuple (h, h)
self.layer = paddle.nn.GRUCell(10, 20)

def forward(self, x):
x = self.layer(x)
return x


class TestDuplicateOutput(Dy2StTestBase):
def _run_static(self):
net = paddle.jit.to_static(SimpleNet())
Expand All @@ -58,5 +64,19 @@ def test_ast_to_func(self):
self._run_static()


class TestDuplicateOutputInPaddleLayer(Dy2StTestBase):
def check_dygraph_and_static_result(self, net, x):
static_net = paddle.jit.to_static(net)
dy_out = net(x)
st_out = static_net(x)
np.testing.assert_allclose(dy_out, st_out)

@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
net = DuplicateOutputInPaddleLayer()
x = paddle.randn([10, 10])
self.check_dygraph_and_static_result(net, x)


if __name__ == '__main__':
unittest.main()

0 comments on commit a1f5cdb

Please sign in to comment.