From 7e915506f960bac300a8407e2072953bb54b8f95 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 19 Sep 2018 14:08:12 -0700 Subject: [PATCH] [MXNET-969] Fix buffer overflow in RNNOp Co-authored-by: Sina Md --- src/operator/rnn-inl.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 1f905eda4a92..4955c586f247 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -552,6 +552,13 @@ class RNNOp : public Operator{ } DType* reserve_space_ptr = static_cast(reserve_space_.dptr); + + int req_statecell = 0; + if (param_.mode == rnn_enum::kLstm) { + // State cell should be present for LSTMs. + req_statecell = req[rnn_enum::kStateCell]; + } + RNNBackward(workspace.dptr_, reserve_space_ptr, param_.num_layers, @@ -576,7 +583,7 @@ class RNNOp : public Operator{ req[rnn_enum::kData], req[rnn_enum::kParams], req[rnn_enum::kState], - req[rnn_enum::kStateCell], + req_statecell, param_.p, param_.mode); }