From 7e915506f960bac300a8407e2072953bb54b8f95 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 19 Sep 2018 14:08:12 -0700 Subject: [PATCH 1/2] [MXNET-969] Fix buffer overflow in RNNOp Co-authored-by: Sina Md --- src/operator/rnn-inl.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 1f905eda4a92..4955c586f247 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -552,6 +552,13 @@ class RNNOp : public Operator{ } DType* reserve_space_ptr = static_cast(reserve_space_.dptr); + + int req_statecell = 0; + if (param_.mode == rnn_enum::kLstm) { + // State cell should be present for LSTMs. + req_statecell = req[rnn_enum::kStateCell]; + } + RNNBackward(workspace.dptr_, reserve_space_ptr, param_.num_layers, @@ -576,7 +583,7 @@ class RNNOp : public Operator{ req[rnn_enum::kData], req[rnn_enum::kParams], req[rnn_enum::kState], - req[rnn_enum::kStateCell], + req_statecell, param_.p, param_.mode); } From e89deca557a70e38e915fb911572290863a0f4fd Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 20 Sep 2018 14:15:44 -0700 Subject: [PATCH 2/2] [MXNET-969] Use ternary op for statecell --- src/operator/rnn-inl.h | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 4955c586f247..9211f6a456fe 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -552,13 +552,6 @@ class RNNOp : public Operator{ } DType* reserve_space_ptr = static_cast(reserve_space_.dptr); - - int req_statecell = 0; - if (param_.mode == rnn_enum::kLstm) { - // State cell should be present for LSTMs. - req_statecell = req[rnn_enum::kStateCell]; - } - RNNBackward(workspace.dptr_, reserve_space_ptr, param_.num_layers, @@ -583,7 +576,8 @@ class RNNOp : public Operator{ req[rnn_enum::kData], req[rnn_enum::kParams], req[rnn_enum::kState], - req_statecell, + // State cell should be present for LSTMs, but is absent for other RNNs. + param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp, param_.p, param_.mode); }