Skip to content

Commit

Permalink
Fix rnn grad bug in cpu when dropout is zero (#37080)
Browse files Browse the repository at this point in the history
* fix rnn grad bug when num_layers is set 2 and dropout_prob is set 0

* add more test for rnn
  • Loading branch information
joey12300 committed Nov 10, 2021
1 parent a787b27 commit b5a0e53
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/operators/rnn_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,9 @@ class RNNCPUKernel : public framework::OpKernel<T> {
}
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());

auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, uint8_t> ones;
ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
// init the output and allocate the memory
output->mutable_data<T>(ctx.GetPlace());
int gate_num = 4;
Expand Down
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/test_rnn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,35 @@ def set_attrs(self):
self.is_bidirec = True


class TestRNNOp5(TestRNNOp):
def set_attrs(self):
self.num_layers = 2


class TestRNNOp6(TestRNNOp):
def set_attrs(self):
self.num_layers = 2
self.is_bidirec = True


class TestRNNOp7(TestRNNOp):
def set_attrs(self):
self.num_layers = 2
self.is_bidirec = True
self.is_test = True


class TestRNNOp8(TestRNNOp):
def set_attrs(self):
self.num_layers = 2
self.is_bidirec = True
self.sequence_length = None


class TestRNNOp9(TestRNNOp):
def set_attrs(self):
self.num_layers = 3


if __name__ == '__main__':
unittest.main()

1 comment on commit b5a0e53

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.