Skip to content

Commit

Permalink
fix static error in summary (#35303)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangna11BD authored Sep 2, 2021
1 parent 25871e0 commit b28cc73
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/hapi/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def forward(self, inputs):
input_size = []
for key in input.keys():
input_size.append(tuple(input[key].shape))
elif isinstance(input, paddle.fluid.framework.Variable):
input_size = tuple(input.shape)
else:
raise ValueError(
"Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,12 @@ def _get_param_from_state_dict(state_dict):
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

def test_summary_input(self):
paddle.enable_static()
mymodel = MyModel()
input_data = paddle.rand([1, 20])
paddle.summary(mymodel, input=input_data)
paddle.disable_static()

rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
input_data = paddle.rand([4, 23, 16])
paddle.summary(rnn, input=input_data)
Expand Down

0 comments on commit b28cc73

Please sign in to comment.