Commit 149eb37: Move operators back to their original files
sandeep-krishnamurthy committed Jan 24, 2019 · 1 parent f7aa93d
Showing 5 changed files with 278 additions and 350 deletions.
214 changes: 214 additions & 0 deletions src/operator/image/image_random-inl.h
@@ -85,6 +85,220 @@ void ToTensor(const nnvm::NodeAttrs &attrs,
});
}

// Normalize Operator
// Parameter registration for image Normalize operator
struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
nnvm::Tuple<float> mean;
nnvm::Tuple<float> default_mean = {0.0f, 0.0f, 0.0f, 0.0f};
nnvm::Tuple<float> std;
nnvm::Tuple<float> default_std = {1.0f, 1.0f, 1.0f, 1.0f};

DMLC_DECLARE_PARAMETER(NormalizeParam) {
DMLC_DECLARE_FIELD(mean)
.set_default(default_mean)
.describe("Sequence of means for each channel. "
"Default value is 0.");
DMLC_DECLARE_FIELD(std)
.set_default(default_std)
.describe("Sequence of standard deviations for each channel. "
"Default value is 1.");
}
};

// Shape and Type inference for image Normalize operator

// Shape inference
inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

const auto& dshape = (*in_attrs)[0];
if (!dshape.ndim()) return false;

CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
<< "Input tensor must have shape (channels, height, width), or "
<< "(N, channels, height, width), but got " << dshape;

int32_t nchannels;
if (dshape.ndim() == 3) {
nchannels = dshape[0];
CHECK(nchannels == 3 || nchannels == 1)
<< "The first dimension of input tensor must be the channel dimension with "
<< "either 1 or 3 elements, but got input with shape " << dshape;
} else if (dshape.ndim() == 4) {
nchannels = dshape[1];
CHECK(nchannels == 3 || nchannels == 1)
<< "The second dimension of input tensor must be the channel dimension with "
<< "either 1 or 3 elements, but got input with shape " << dshape;
}

CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels))
<< "Invalid mean for input with shape " << dshape
<< ". mean must have either 1 or " << nchannels
<< " elements, but got " << param.mean;
CHECK(param.std.ndim() == 1 || param.std.ndim() == nchannels)
<< "Invalid std for input with shape " << dshape
<< ". std must have either 1 or " << nchannels
<< " elements, but got " << param.std;

SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
return true;
}

// Type Inference
inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

// Normalized Tensor will be a float
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
return out_attrs->at(0) != -1;
}

template<int req>
struct normalize_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
const int i, const int length, const int step,
const DType mean, const DType std_dev) {
KERNEL_ASSIGN(out_data[step + i*length + j], req,
(in_data[step + i*length + j] - mean) / std_dev);
}
};
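
The Map function above normalizes a single element j of channel i, addressing the flat buffer as step + i*length + j, where length = H*W and i is driven by the channel loop in NormalizeImpl below. A minimal standalone sketch of the same per-channel arithmetic (editor's illustration with made-up sizes and values; it does not depend on MXNet and is not part of this diff):

// Editor's sketch: per-channel normalization on a plain CHW buffer, where
// channel i occupies the contiguous range [i*length, (i+1)*length).
#include <cstdio>
#include <vector>

int main() {
  const int C = 2, H = 2, W = 3;
  const int length = H * W;                    // elements per channel
  std::vector<float> in(C * length), out(C * length);
  for (int k = 0; k < C * length; ++k) in[k] = static_cast<float>(k);

  const float mean[2]    = {1.0f, 2.0f};       // illustrative per-channel values
  const float std_dev[2] = {2.0f, 4.0f};

  for (int i = 0; i < C; ++i)                  // i: channel (outer loop in NormalizeImpl)
    for (int j = 0; j < length; ++j)           // j: element within channel (kernel index)
      out[i * length + j] = (in[i * length + j] - mean[i]) / std_dev[i];

  std::printf("out[0]=%.2f out[%d]=%.2f\n", out[0], length, out[length]);
  return 0;
}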

template<typename xpu>
void NormalizeImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const NormalizeParam &param,
const int length,
const int channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();

for (int i = 0; i < channel; ++i) {
DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
DType std_dev = param.std[param.std.ndim() > 1 ? i : 0];
mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
s, length, output, input,
i, length, step, mean, std_dev);
}
});
});
}

template<typename xpu>
void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

// 3D input (c, h, w)
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const int channel = inputs[0].shape_[0];
NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
const int channel = inputs[0].shape_[1];
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
}
}
}
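
In the 4D branch, each sample n is a contiguous block of step = channel*length elements, so passing n*step shifts the kernel onto that sample's slice. A tiny standalone sketch of the offset arithmetic (editor's illustration, not part of this diff):

// Editor's sketch: element (n, c, h, w) sits at n*step + c*length + h*W + w,
// which is the usual NCHW flattening ((n*C + c)*H + h)*W + w.
#include <cstdio>

int main() {
  const int C = 3, H = 4, W = 5;
  const int length = H * W;          // elements per channel
  const int step = C * length;       // elements per sample
  const int n = 1, c = 2, h = 3, w = 4;
  std::printf("%d == %d\n",
              n * step + c * length + h * W + w,
              ((n * C + c) * H + h) * W + w);   // both print 119
  return 0;
}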

// Backward function
template<int req>
struct normalize_backward {
template<typename DType>
MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad,
const DType* in_data, const int i, const int length,
const int step, const DType std_dev) {
// d/dx{(x - mean) / std_dev} => (1 / std_dev)
KERNEL_ASSIGN(in_grad[step + i*length + j], req,
out_grad[step + i*length + j] * (1.0 / std_dev));
}
};
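
The backward kernel is the chain rule applied to the forward transform; only the data gradient is produced (mean and std receive no gradient here):

.. math::
    y = \frac{x - m}{s} \quad\Rightarrow\quad
    \frac{\partial y}{\partial x} = \frac{1}{s} \quad\Rightarrow\quad
    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{1}{s}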

template<typename xpu>
void NormalizeBackwardImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const NormalizeParam &param,
const int length,
const int channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& out_grad = inputs[0];
const TBlob& in_data = inputs[1];
const TBlob& in_grad = outputs[0];
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
for (int i = 0; i < channel; ++i) {
DType std_dev = param.std[param.std.ndim() > 1 ? i : 0];
mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
s, length, in_grad.dptr<DType>(), out_grad.dptr<DType>(),
in_data.dptr<DType>(), i, length, step, std_dev);
}
});
});
}

template<typename xpu>
void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

// Note: inputs[0] is out_grad
const TBlob& in_data = inputs[1];

// 3D input (c, h, w)
if (in_data.ndim() == 3) {
const int length = in_data.shape_[1] * in_data.shape_[2];
const int channel = in_data.shape_[0];
NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
} else if (in_data.ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = in_data.shape_[0];
const int length = in_data.shape_[2] * in_data.shape_[3];
const int channel = in_data.shape_[1];
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
}
}
}
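
As a sanity check on the backward path, the analytic factor 1/std_dev can be compared against a central finite difference of the forward transform. A standalone sketch (editor's illustration with made-up values; no MXNet dependencies, not part of this diff):

// Editor's sketch: finite-difference check of d/dx (x - m)/s == 1/s.
#include <cmath>
#include <cstdio>

static float normalize_one(float x, float m, float s) { return (x - m) / s; }

int main() {
  const float m = 0.456f, s = 0.224f, x = 0.7f, eps = 1e-3f;
  const float numeric = (normalize_one(x + eps, m, s) -
                         normalize_one(x - eps, m, s)) / (2 * eps);
  const float analytic = 1.0f / s;   // factor applied to out_grad in normalize_backward
  std::printf("numeric=%.4f analytic=%.4f diff=%.6f\n",
              numeric, analytic, std::fabs(numeric - analytic));
  return 0;
}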

template<typename DType>
inline DType saturate_cast(const float& src) {
return static_cast<DType>(src);
60 changes: 60 additions & 0 deletions src/operator/image/image_random.cc
@@ -32,6 +32,7 @@ namespace mxnet {
namespace op {
namespace image {

DMLC_REGISTER_PARAMETER(NormalizeParam);
DMLC_REGISTER_PARAMETER(RandomEnhanceParam);
DMLC_REGISTER_PARAMETER(AdjustLightingParam);
DMLC_REGISTER_PARAMETER(RandomLightingParam);
@@ -47,6 +48,65 @@ NNVM_REGISTER_OP(_image_to_tensor)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
.add_argument("data", "NDArray-or-Symbol", "The input.");

NNVM_REGISTER_OP(_image_normalize)
.describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
standard deviation.
Given mean `(m1, ..., mn)` and std `(s\ :sub:`1`\ , ..., s\ :sub:`n`)` for `n` channels,
this transform normalizes each channel of the input tensor with:
.. math::
output[i] = (input[i] - m\ :sub:`i`\ ) / s\ :sub:`i`
If mean or std is scalar, the same value will be applied to all channels.
Default value for mean is 0.0 and stand deviation is 1.0.
Example:
.. code-block:: python
image = mx.nd.random.uniform(0, 1, (3, 4, 2))
normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
[[[ 0.18293785 0.19761486]
[ 0.23839645 0.28142193]
[ 0.20092112 0.28598186]
[ 0.18162774 0.28241724]]
[[-0.2881726 -0.18821815]
[-0.17705294 -0.30780914]
[-0.2812064 -0.3512327 ]
[-0.05411351 -0.4716435 ]]
[[-1.0363373 -1.7273437 ]
[-1.6165586 -1.5223348 ]
[-1.208275 -1.1878313 ]
[-1.4711051 -1.5200229 ]]]
<NDArray 3x4x2 @cpu(0)>
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<NormalizeParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<nnvm::FInferShape>("FInferShape", NormalizeOpShape)
.set_attr<nnvm::FInferType>("FInferType", NormalizeOpType)
.set_attr<FCompute>("FCompute<cpu>", NormalizeOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_image_normalize"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NormalizeParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_image_normalize)
.set_attr_parser(ParamParser<NormalizeParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NormalizeOpBackward<cpu>);

MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right)
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", FlipLeftRight);
@@ -18,10 +18,11 @@
*/

/*!
* \file normalize_op.cu
* \brief GPU Implementation of Normalize op
* \file image_random.cu
* \brief GPU Implementation of image transformation operators
*/
#include "./normalize_op-inl.h"
#include "./image_random-inl.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {