Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Move back changes to original image operators files
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeep-krishnamurthy committed Jan 24, 2019
1 parent 691b77c commit 4ec2c1f
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 215 deletions.
103 changes: 102 additions & 1 deletion src/operator/image/image_random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <vector>
#include <cmath>
#include <limits>
#include <algorithm>
#include <utility>
#include "../mxnet_op.h"
#include "../operator_common.h"
Expand All @@ -40,6 +39,108 @@ namespace mxnet {
namespace op {
namespace image {

// There are no parameters for this operator.
// Hence, no arameter registration.

// Shape and Type inference for image to tensor operator
inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

TShape &shp = (*in_attrs)[0];
if (!shp.ndim()) return false;

CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
<< "Input image must have shape (height, width, channels), or "
<< "(N, height, width, channels) but got " << shp;
if (shp.ndim() == 3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
} else if (shp.ndim() == 4) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]}));
}

return true;
}

inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
return (*in_attrs)[0] != -1;
}

// Operator Implementation

template<int req>
struct totensor_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data,
const int c, const int length, const int channel,
const int step, const float normalize_factor = 255.0f) {
KERNEL_ASSIGN(out_data[step + c*length + l], req,
(in_data[step + l*channel + c]) / normalize_factor);
}
};

template<typename xpu>
void ToTensorImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const int channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();

for (int c = 0; c < channel; ++c) {
mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
s, length, output, input, c, length, channel, step);
}
});
});
}

template<typename xpu>
void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace updates";

// 3D Input - (h, w, c)
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
const int channel = inputs[0].shape_[2];
ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, h, w, c)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const int channel = inputs[0].shape_[3];
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step);
}
}
}

struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
nnvm::Tuple<float> mean;
nnvm::Tuple<float> std;
Expand Down
39 changes: 39 additions & 0 deletions src/operator/image/image_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,45 @@ DMLC_REGISTER_PARAMETER(AdjustLightingParam);
DMLC_REGISTER_PARAMETER(RandomLightingParam);
DMLC_REGISTER_PARAMETER(RandomColorJitterParam);

NNVM_REGISTER_OP(_image_to_tensor)
.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C)
with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W)
with values in the range [0, 1)
Example:
.. code-block:: python
image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
to_tensor(image)
[[[ 0.85490197 0.72156864]
[ 0.09019608 0.74117649]
[ 0.61960787 0.92941177]
[ 0.96470588 0.1882353 ]]
[[ 0.6156863 0.73725492]
[ 0.46666667 0.98039216]
[ 0.44705883 0.45490196]
[ 0.01960784 0.8509804 ]]
[[ 0.39607844 0.03137255]
[ 0.72156864 0.52941179]
[ 0.16470589 0.7647059 ]
[ 0.05490196 0.70588237]]]
<NDArray 3x4x2 @cpu(0)>
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ToTensorShape)
.set_attr<nnvm::FInferType>("FInferType", ToTensorType)
.set_attr<FCompute>("FCompute<cpu>", ToTensorOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
.add_argument("data", "NDArray-or-Symbol", "Input ndarray");

NNVM_REGISTER_OP(_image_normalize)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./totensor_op-inl.h"

/*!
* \file image_random.cu
* \brief GPU Implementation of image transformation operators
*/
#include "./image_random-inl.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {
Expand Down
142 changes: 0 additions & 142 deletions src/operator/image/totensor_op-inl.h

This file was deleted.

71 changes: 0 additions & 71 deletions src/operator/image/totensor_op.cc

This file was deleted.

0 comments on commit 4ec2c1f

Please sign in to comment.