Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix backward_clip num inputs and type of clip params #15688

Merged
merged 6 commits into from
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
struct clip {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* datas,
DType a_min, DType a_max) {
const float a_min, const float a_max) {
DType data = datas[i];
if (data > a_max) {
out[i] = a_max;
Expand All @@ -1473,7 +1473,7 @@ struct clip {
struct clip_grad {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* grad, const DType* datas,
DType a_min, DType a_max) {
const float a_min, const float a_max) {
DType data = datas[i];
if (data > a_max) {
out[i] = 0;
Expand All @@ -1500,7 +1500,7 @@ void Clip(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
mxnet_op::Kernel<mxnet::op::clip, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(),
DType(param.a_min), DType(param.a_max));
param.a_min, param.a_max);
});
}

Expand Down Expand Up @@ -1529,7 +1529,7 @@ void ClipGrad_(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<clip_grad, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), DType(param.a_min), DType(param.a_max));
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), param.a_min, param.a_max);
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ parameter values:
.add_arguments(ClipParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_clip)
.set_num_inputs(1)
.set_num_inputs(2)
DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
.set_num_outputs(1)
.set_attr_parser(ParamParser<ClipParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
Expand Down