This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Fix softmax behavior to not cast up the accumulator if no output dtype is specified #14759
Closed

Changes from all commits (6 commits):
0901ca2  upcast Softmax Accumulator type only when output dtype is specified  (nswamy)
4754dc7  set softmax Accumulator type to the DType passed  (nswamy)
2f66745  Update test_softmax_dtype: use AType in np_softmax, change tolerance …  (nswamy)
b3dc30d  changes to check_numeric_grad to support odtype and test_softmax_dtype  (nswamy)
feda006  revert tests back to earlier since verifying against numpy on such a …  (nswamy)
c47b024  don't remove MXNET_REAL_ACC_TYPE_SWITCH, since its used by other oper…  (nswamy)
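For context on what these commits change: the softmax reduction's accumulator is no longer widened silently when the caller does not ask for an output dtype. The standalone C++ sketch below is illustrative only, not MXNet code, and the element count and values are made up; it simply shows why the accumulator type matters when summing many small softmax-style terms.

#include <cstdio>
#include <vector>

// Accumulate the same data in a narrow type vs. a wider type, mimicking the
// "accumulate in the input dtype" vs. "upcast the accumulator" choice.
template <typename AccType>
AccType Sum(const std::vector<float>& data) {
  AccType acc = AccType(0);
  for (float v : data) acc += static_cast<AccType>(v);
  return acc;
}

int main() {
  std::vector<float> data(1 << 24, 1e-4f);   // many small softmax-style terms
  std::printf("float  accumulator: %.6f\n", static_cast<double>(Sum<float>(data)));
  std::printf("double accumulator: %.6f\n", Sum<double>(data));
  // The float accumulator drifts once the running sum dwarfs each term;
  // the double accumulator stays close to the expected 1677.7216.
  return 0;
}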
@@ -64,7 +64,8 @@ struct log_softmax_fwd {
 };

-template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+template<typename OP, bool negate, typename AType, typename DType, typename
+         OType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
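As a reading aid for the template above, here is a simplified single-row sketch of how the separate AType, DType, and OType parameters are typically used; it is my own illustration, not the actual MXNet kernel. Inputs are read as DType, the max and the exp-sum are carried in AType, and results are written out as OType.

#include <algorithm>
#include <cmath>

// Simplified single-row softmax sketch (illustrative only): the accumulator
// work happens in AType, loads are DType, stores are OType.
template <typename AType, typename DType, typename OType>
void SoftmaxRow(const DType* in, OType* out, int M, DType temperature) {
  AType mmax = static_cast<AType>(in[0]);
  for (int i = 1; i < M; ++i) mmax = std::max(mmax, static_cast<AType>(in[i]));

  AType sum = AType(0);
  for (int i = 0; i < M; ++i)
    sum += std::exp((static_cast<AType>(in[i]) - mmax) / static_cast<AType>(temperature));

  for (int i = 0; i < M; ++i)
    out[i] = static_cast<OType>(
        std::exp((static_cast<AType>(in[i]) - mmax) / static_cast<AType>(temperature)) / sum);
}

With AType equal to DType this reproduces the "no upcast" behavior; choosing a wider AType is exactly the upcast that the dtype parameter now controls.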
@@ -310,9 +311,9 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };

-static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+static inline int sofmtax_dtype_param(const nnvm::NodeAttrs &attrs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  return param.dtype.has_value() && param.dtype.value() != -1;
+  return param.dtype.has_value() ? param.dtype.value(): -1;
 }

 static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
@@ -322,7 +323,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);

-  if (softmax_has_dtype_override(attrs)) {
+  if (sofmtax_dtype_param(attrs) != -1) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
     type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
     return true;
@@ -334,7 +335,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
                                       mxnet::ShapeVector *in_attrs,
                                       mxnet::ShapeVector *out_attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (sofmtax_dtype_param(attrs) != -1) {
     return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
   } else {
     return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
@@ -345,7 +346,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
   CHECK_EQ(out_attrs->size(), 1);
-  if (softmax_has_dtype_override(attrs)) {
+  if (sofmtax_dtype_param(attrs) != -1) {
     CHECK_EQ(in_attrs->size(), 3);
     int in_dtype = (*in_attrs)[1];
     int out_dtype = (*in_attrs)[2];
@@ -365,19 +366,19 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,

 static inline std::vector<std::pair<int, int> >
 SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (sofmtax_dtype_param(attrs) != -1) {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
   } else {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
   }
 }

 static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  return softmax_has_dtype_override(attrs) ? 3 : 2;
+  return sofmtax_dtype_param(attrs) != -1 ? 3 : 2;
 }

 static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (sofmtax_dtype_param(attrs) != -1) {
     return std::vector<std::string>{"ograd", "data", "output"};
   } else {
     return std::vector<std::string>{"ograd", "output"};
@@ -388,7 +389,7 @@ struct SoftmaxFGradient {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
-    if (softmax_has_dtype_override(n->attrs)) {
+    if (sofmtax_dtype_param(n->attrs) != -1) {
       return ElemwiseGradUseInOut {op_name}(n, ograds);
     } else {
       return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -410,19 +411,25 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
                              param.temperature.value() : 1.0;
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      if (shape.ndim() == 2) {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<2>(), axis,
-            static_cast<DType>(temperature));
-      } else {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<3>(), axis,
-            static_cast<DType>(temperature));
-      }
+
+  int atype_flag_ = sofmtax_dtype_param(attrs);
+  atype_flag_ = atype_flag_ != -1 ? atype_flag_ : inputs[0].type_flag_;
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
+      });
     });
   });
 }
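The net effect of the new dispatch in SoftmaxCompute: the accumulator type AType now follows the dtype argument when one is supplied and otherwise stays equal to the input type, instead of coming from the old MXNET_REAL_ACC_TYPE_SWITCH path, which cast the accumulator up regardless of the dtype argument. A minimal sketch of that selection rule (illustrative helper, not part of the PR):

// Illustrative only, not part of the PR: the accumulator type flag follows
// the explicit dtype parameter when one was given and otherwise falls back
// to the input's type flag, so nothing is upcast implicitly.
inline int ResolveAccumulatorType(int dtype_param /* -1 when unset */,
                                  int input_type_flag) {
  return dtype_param != -1 ? dtype_param : input_type_flag;
}

So a float16 input with no dtype argument now accumulates in float16, while passing dtype='float32' selects a float32 accumulator along with a float32 output.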
@@ -442,23 +449,28 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
                              param.temperature.value() : 1.0;
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);

-  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
-
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (shape.ndim() == 2) {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<2>(), axis, static_cast<DType>(temperature));
-        } else {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<3>(), axis, static_cast<DType>(temperature));
-        }
-      });
+  int out_idx = sofmtax_dtype_param(attrs) != -1 ? 2 : 1;
+
+  int atype_flag_ = sofmtax_dtype_param(attrs);
+  atype_flag_ = atype_flag_ != -1 ? atype_flag_ : inputs[0].type_flag_;
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, OType, {
+    MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
+        });
+      });
     });
   });
 }

Review comment on the new `int out_idx = sofmtax_dtype_param(attrs) != -1 ? 2 : 1;` line: I'd suggest we have an env var to trigger stable reduction instead of relying on dtype argument, including other ops like norm.
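To make the reviewer's suggestion concrete, here is a rough sketch of what an environment-variable switch for stable reduction could look like. Everything in it is hypothetical: the variable name, the helpers, and the wiring are my illustration, not code from this PR or from MXNet.

#include <cstdlib>
#include <cstring>

// Hypothetical: read an env var to decide whether reductions should
// accumulate in a wider type even when no dtype argument is given.
inline bool StableAccumulationRequested() {
  const char* v = std::getenv("MXNET_SAFE_ACCUMULATION");  // hypothetical flag name
  return v != nullptr && std::strcmp(v, "0") != 0;
}

// Hypothetical wiring: an explicit dtype argument still wins; otherwise the
// env var chooses between a widened flag and the input's own type flag.
inline int ResolveAccumulatorTypeWithEnv(int dtype_param, int input_type_flag,
                                         int widened_type_flag) {
  if (dtype_param != -1) return dtype_param;
  return StableAccumulationRequested() ? widened_type_flag : input_type_flag;
}

The same switch could then be consulted by other reduction-heavy operators such as norm, as the comment suggests.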
Review comment (on the `sofmtax_dtype_param` name): typo? softmax