
Commit: add dtype
szha committed Feb 11, 2019
1 parent eea52d3 commit 0c93cde
Showing 4 changed files with 226 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/operator/mxnet_op.h
@@ -254,7 +254,7 @@ inline int get_num_threads<cpu>(const int N) {
case mshadow::kFloat32: \
{ \
typedef float DType; \
typedef float AType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
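
A note on the change above: MXNET_REAL_ACC_TYPE_SWITCH now binds AType to double when the data is float32, so reductions such as the softmax normalizer accumulate in a wider type than they load and store. A minimal standalone sketch of why that matters (illustrative only, not part of this commit; names and sizes are made up for the example):

#include <cstdio>
#include <vector>

// Read DType elements but accumulate in AType, mirroring the DType/AType split
// that the switch macro above provides to its callers.
template <typename AType, typename DType>
AType accumulate(const std::vector<DType>& v) {
  AType acc = AType(0);
  for (DType x : v) acc += static_cast<AType>(x);
  return acc;
}

int main() {
  std::vector<float> v((1 << 24) + 16, 1.0f);
  // With a float accumulator the running sum gets stuck at 2^24 = 16777216,
  // because 16777217 is not representable in single precision.
  std::printf("float  accumulator: %.1f\n", static_cast<double>(accumulate<float>(v)));
  // A double accumulator keeps all 16777232 increments.
  std::printf("double accumulator: %.1f\n", accumulate<double>(v));
  return 0;
}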
151 changes: 106 additions & 45 deletions src/operator/nn/softmax-inl.h
@@ -52,8 +52,8 @@ struct log_softmax_fwd {
};


template<typename OP, bool negate, typename DType, typename AType, int ndim>
inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
index_t N = shape.Size()/M;
@@ -75,8 +75,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,

AType sum = AType(0);
DType in_val;
// By default temperature is 1.0, and only in reinforcement training
// users would set it to other values.
// By default temperature is 1.0.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
if (temperature == 1.0) {
for (index_t j = 0; j < M; ++j) {
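
The reworked CPU Softmax above now carries three element types: DType for the input, AType for the max/sum accumulators, and OType for the output, which is what lets a float16 input produce a float32 or float64 result. A simplified one-row sketch of that pattern (not the function above; it ignores negate, strides and the axis logic):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Softmax over a contiguous row of length M: read DType, accumulate in AType,
// write OType.  The temperature branch mirrors the 'divide-by-1' shortcut above.
template <typename AType, typename DType, typename OType>
void softmax_row(const DType* in, OType* out, std::size_t M,
                 DType temperature = DType(1)) {
  AType mmax = static_cast<AType>(in[0]);
  for (std::size_t j = 1; j < M; ++j)
    mmax = std::max(mmax, static_cast<AType>(in[j]));

  AType sum = AType(0);
  if (temperature == DType(1)) {
    for (std::size_t j = 0; j < M; ++j)
      sum += std::exp(static_cast<AType>(in[j]) - mmax);
    for (std::size_t j = 0; j < M; ++j)
      out[j] = static_cast<OType>(std::exp(static_cast<AType>(in[j]) - mmax) / sum);
  } else {
    const AType t = static_cast<AType>(temperature);
    for (std::size_t j = 0; j < M; ++j)
      sum += std::exp((static_cast<AType>(in[j]) - mmax) / t);
    for (std::size_t j = 0; j < M; ++j)
      out[j] = static_cast<OType>(std::exp((static_cast<AType>(in[j]) - mmax) / t) / sum);
  }
}

For instance, softmax_row<double>(float_in, double_out, M) reads float32, accumulates in float64 and writes float64, matching the DType/AType/OType roles of the template above.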
@@ -119,8 +118,9 @@ struct log_softmax_bwd {
};


template<typename OP1, typename OP2, int Req, bool negate, typename DType, typename AType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
template<typename OP1, typename OP2, int Req, bool negate,
typename AType, typename DType, typename OType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const DType temperature) {
index_t M = shape[axis];
@@ -139,8 +139,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
}

// By default temperature is 1.0, and only in reinforcement training
// users would set it to other values.
// By default temperature is 1.0.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
DType final_result;
if (temperature == 1.0) {
@@ -163,8 +162,9 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,


#ifdef __CUDACC__
template<int x_bits, typename OP, bool negate, typename DType, typename AType, int ndim>
__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
template<int x_bits, typename OP, bool negate, typename AType, int ndim,
typename DType, typename OType>
__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
const double temperature) {
const unsigned x_size = 1 << x_bits;
@@ -201,8 +201,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
}
}

template<typename OP, bool negate, typename DType, typename AType, int ndim>
inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
Shape<ndim> shape, int axis, const double temperature) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
@@ -212,16 +212,16 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
Shape<ndim> sshape = shape;
sshape[axis] = 1;

softmax_compute_kernel<x_bits, OP, negate, DType, AType, ndim>
softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
in, out, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
}


template<int x_bits, typename OP1, typename OP2, int Req, bool negate,
typename DType, typename AType, int ndim>
__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
typename DType, typename OType>
__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, const double temperature) {
const unsigned x_size = 1 << x_bits;
@@ -251,8 +251,9 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
}


template<typename OP1, typename OP2, int Req, bool negate, typename DType, typename AType, int ndim>
inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
typename DType, typename OType>
inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const double temperature) {
const int x_bits = 7;
@@ -263,7 +264,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
Shape<ndim> sshape = shape;
sshape[axis] = 1;

softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, AType, ndim>
softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
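
For orientation, with the operator pair used by plain softmax (OP1 = mshadow_op::mul, OP2 = softmax_bwd, see the registrations in softmax.cc below), the CPU and GPU gradient paths above both compute igrad_i = out_i * (ograd_i - sum_k ograd_k * out_k), scaled by 1/temperature when a temperature is set. A one-row sketch of that formula (illustrative; it skips the Req/KERNEL_ASSIGN write mode and the strided indexing):

#include <cstddef>

// Gradient of softmax over one row: dL/dx_i = y_i * (g_i - sum_k g_k * y_k) / T,
// where y = softmax(x / T) and g = dL/dy.  out/ograd use OType, igrad uses DType,
// and the dot product accumulates in AType, matching SoftmaxGrad above.
template <typename AType, typename DType, typename OType>
void softmax_grad_row(const OType* out, const OType* ograd, DType* igrad,
                      std::size_t M, DType temperature = DType(1)) {
  AType sum = AType(0);
  for (std::size_t j = 0; j < M; ++j)
    sum += static_cast<AType>(ograd[j]) * static_cast<AType>(out[j]);

  for (std::size_t j = 0; j < M; ++j) {
    AType g = static_cast<AType>(out[j]) * (static_cast<AType>(ograd[j]) - sum);
    igrad[j] = static_cast<DType>(g / static_cast<AType>(temperature));
  }
}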
@@ -276,14 +277,70 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
int axis;
dmlc::optional<double> temperature;
dmlc::optional<int> dtype;
DMLC_DECLARE_PARAMETER(SoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis along which to compute softmax.");
.describe("The axis along which to compute softmax.");
DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
.describe("Temperature parameter in softmax");
.describe("Temperature parameter in softmax");
DMLC_DECLARE_FIELD(dtype)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(dmlc::optional<int>())
.describe("DType of the output in case this can't be inferred. "
"Defaults to the same as input's dtype if not defined (dtype=None).");
}
};

inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);

int arg_dtype = param.dtype.has_value()?param.dtype.value():-1,
in_dtype = (*in_attrs)[0],
out_dtype = (*out_attrs)[0];

if (out_dtype != -1 && in_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
return true;
} else if (in_dtype != -1) {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
}
return true;
} else if (out_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
return true;
} else {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
}
return false;
}
}
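
The inference rule added above is: if the dtype argument is given, the output takes it; otherwise the output mirrors the input's dtype, and inference succeeds once the output type is known. TYPE_ASSIGN_CHECK is used here with the semantics assumed in the sketch below (assign when a slot is still unknown, otherwise verify consistency); the helper name is hypothetical and only illustrates that contract:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the assumed TYPE_ASSIGN_CHECK behaviour: propagate
// `value` into slot i if the slot is still -1 (unknown), otherwise require the
// existing entry to match; report a conflict by returning false.
inline bool assign_or_check(std::vector<int>* attrs, std::size_t i, int value) {
  if (value == -1) return true;               // nothing known to propagate
  if ((*attrs)[i] == -1) { (*attrs)[i] = value; return true; }
  return (*attrs)[i] == value;
}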

inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3);
CHECK_EQ(out_attrs->size(), 1);

int in_dtype = (*in_attrs)[1],
out_dtype = (*in_attrs)[2];
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);

return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
}

template<typename xpu, typename OP, bool negate = false>
void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -299,17 +356,19 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
if (shape.ndim() == 2) {
Softmax<OP, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
} else {
Softmax<OP, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
}
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
if (shape.ndim() == 2) {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
} else {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
}
});
});
}
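
Since the input and output dtypes can now differ, the forward pass above dispatches twice: the outer MXNET_REAL_ACC_TYPE_SWITCH binds DType and AType from the input, and the nested MSHADOW_REAL_TYPE_SWITCH binds OType from the output before the templated Softmax is called. A stripped-down sketch of how this kind of runtime-flag-to-template dispatch works in general (the macro internals are not shown in this diff, so this is only an analogue):

enum TypeFlag { kFloat32 = 0, kFloat64 = 1 };

// Turn a runtime type flag into a compile-time type by handing the callback a
// value of the selected type; nesting two of these gives DType x OType dispatch.
template <typename F>
void dispatch_real(int flag, F&& f) {
  switch (flag) {
    case kFloat32: f(float{});  break;
    case kFloat64: f(double{}); break;
    default: break;  // other dtypes omitted in this sketch
  }
}

template <typename DType, typename OType>
void softmax_stub(const DType* /*in*/, OType* /*out*/) { /* kernel call would go here */ }

inline void dispatch_softmax(int in_flag, int out_flag, const void* in, void* out) {
  dispatch_real(in_flag, [&](auto d) {
    using DType = decltype(d);
    dispatch_real(out_flag, [&](auto o) {
      using OType = decltype(o);
      softmax_stub(static_cast<const DType*>(in), static_cast<OType*>(out));
    });
  });
}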

@@ -327,19 +386,21 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2, Req, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
}
MXNET_REAL_ACC_TYPE_SWITCH(inputs[2].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
}
});
});
});
}
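
Note the index change in the backward compute above: because the gradient node is now registered with three inputs (see the _backward_softmax registration in softmax.cc below), the forward output moves from inputs[1] to inputs[2] and is read as OType, while the produced igrad stays in the data's DType. A hypothetical set of index constants (not in the original source) spells out the convention:

// Input layout of the 3-input backward node, matching the FListInputNames
// entries {"ograd", "data", "output"} registered in softmax.cc below.
enum SoftmaxGradInput : int {
  kOgrad  = 0,  // incoming gradient w.r.t. the forward output (OType)
  kData   = 1,  // original forward input; not read by the compute shown here
  kOutput = 2   // forward output, i.e. the softmax values (OType)
};
// SoftmaxGradCompute above passes inputs[kOutput] and inputs[kOgrad] to
// SoftmaxGrad and writes outputs[0] (igrad) in DType.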
90 changes: 81 additions & 9 deletions src/operator/nn/softmax.cc
@@ -67,7 +67,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
}
#endif

MXNET_OPERATOR_REGISTER_UNARY(softmax)
NNVM_REGISTER_OP(softmax)
.describe(R"code(Applies the softmax function.
The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
@@ -102,15 +102,39 @@ Example::
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmax"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmax"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
NNVM_REGISTER_OP(_backward_softmax)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
})
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd>);

MXNET_OPERATOR_REGISTER_UNARY(softmin)
NNVM_REGISTER_OP(softmin)
.describe(R"code(Applies the softmin function.
The resulting array contains elements in the range (0,1) and the elements along the given axis sum
@@ -141,15 +165,39 @@
return std::vector<std::string>{"output"};
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmin"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
NNVM_REGISTER_OP(_backward_softmin)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
})
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd, true>);

MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
NNVM_REGISTER_OP(log_softmax)
.describe(R"code(Computes the log softmax of the input.
This is equivalent to computing softmax followed by log.
@@ -168,10 +216,34 @@ Examples::
)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_log_softmax"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY(_backward_log_softmax)
NNVM_REGISTER_OP(_backward_log_softmax)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
})
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);
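
For the log_softmax registration just above, OP1 = mshadow_op::left makes the reduction keep ograd unchanged, so sum reduces to sum_k ograd_k, and OP2 = mxnet_op::log_softmax_bwd is assumed to implement the standard derivative ograd_i - exp(out_i) * sum (that functor's body is not shown in this diff). Under that assumption, a one-row sketch of the log_softmax gradient, including the 1/temperature factor applied by the kernels:

#include <cmath>
#include <cstddef>

// Gradient of log-softmax over one row: dL/dx_i = (g_i - exp(y_i) * sum_k g_k) / T,
// where y = log_softmax(x / T) and g = dL/dy.
template <typename AType, typename DType, typename OType>
void log_softmax_grad_row(const OType* out, const OType* ograd, DType* igrad,
                          std::size_t M, DType temperature = DType(1)) {
  AType sum = AType(0);
  for (std::size_t j = 0; j < M; ++j)
    sum += static_cast<AType>(ograd[j]);

  for (std::size_t j = 0; j < M; ++j) {
    AType g = static_cast<AType>(ograd[j]) -
              std::exp(static_cast<AType>(out[j])) * sum;
    igrad[j] = static_cast<DType>(g / static_cast<AType>(temperature));
  }
}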
(diff for the fourth changed file not shown)
