-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add GRU Operator #5255
Add GRU Operator #5255
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
"(string, default sigmoid) " | ||
"The activation type used in update gate and reset gate.") | ||
.SetDefault("sigmoid"); | ||
AddAttr<bool>("is_reverse", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gate_activation -> gateActivation
is_reverse -> isReverse
paddle/operators/gru_op.cc
Outdated
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); | ||
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); | ||
// ctx->ShareLoD("Input", "Gate"); | ||
// ctx->ShareLoD("Input", "ResetHiddenPrev"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove these lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/gru_op.cc
Outdated
GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("Input", | ||
"(LoDTensor) The first input is a LodTensor, which support " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
support -> supports
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/math/gru_compute.h
Outdated
// } else { | ||
// PADDLE_THROW("Do not support activation type."); | ||
// } | ||
// } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove these lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
|
||
def relu(x): | ||
return np.maximum(x, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can these functions be imported from test_lstm_op.py ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
return batch_gate, batch_reset_hidden_prev, hidden | ||
|
||
def set_data(self): | ||
lod = [[0, 2, 6, 9]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lod = [[0, 2, 6, 9]] -> lod = [[0, 2, 6, batch_size]] ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
cd414bc
to
c4f7f3a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I approve this PR. the code style still needs to update later. And please add the TODO description before this PR is merged.
Resolves #5254
Rewrite GatedRecurrentLayer in the new framework. Some code styles haven't formatted.