-
Notifications
You must be signed in to change notification settings - Fork 6.8k
add symbol.SwapAxis operator, just can do Forward(). #502
Changes from 8 commits
4d749bb
ab19122
0ca99da
1dd6ada
1a2dc81
a0ec24d
23cfadc
fcd62fd
a21d3be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
+3 −3 | doc/README.md | |
+2 −0 | guide/basic.cpp | |
+1 −0 | mshadow/base.h | |
+7 −7 | mshadow/cuda/tensor_gpu-inl.cuh | |
+57 −3 | mshadow/dot_engine-inl.h | |
+1 −0 | mshadow/extension.h | |
+156 −0 | mshadow/extension/slice.h | |
+28 −2 | mshadow/tensor.h | |
+13 −9 | mshadow/tensor_cpu-inl.h |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis-inl.h | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
#ifndef MXNET_OPERATOR_SWAPAXIS_INL_H_ | ||
#define MXNET_OPERATOR_SWAPAXIS_INL_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/parameter.h> | ||
#include <mxnet/operator.h> | ||
#include <algorithm> | ||
#include <map> | ||
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include "./operator_common.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
struct SwapAxis{ | ||
enum SwapAxisOpInputs {kData}; | ||
enum SwapAxisOpOutputs {kOut}; | ||
}; | ||
|
||
using namespace mshadow; | ||
using namespace mshadow::expr; | ||
|
||
struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> { | ||
// use int for enumeration | ||
uint32_t dim1, dim2; | ||
DMLC_DECLARE_PARAMETER(SwapAxisParam) { | ||
DMLC_DECLARE_FIELD(dim1) | ||
.set_default(0) | ||
.describe("the first axis to be swapped."); | ||
DMLC_DECLARE_FIELD(dim2) | ||
.set_default(0) | ||
.describe("the second axis to be swapped."); | ||
} | ||
}; | ||
|
||
|
||
template<typename xpu> | ||
class SwapAxisOp : public Operator { | ||
public: | ||
explicit SwapAxisOp(SwapAxisParam p) { | ||
CHECK_NE(p.dim1, p.dim2) << "dim1 can not be equal dim2."; | ||
this->param_ = p; | ||
} | ||
|
||
void Reshape2Five(Shape<5> *inter_shape, const TShape &shape, uint32_t dim1, uint32_t dim2) { | ||
index_t ndim_in = shape.ndim(); | ||
int si; | ||
|
||
if (dim1 > dim2) { | ||
std::swap(dim1, dim2); | ||
} | ||
|
||
for (si = 0; si < 5; si++) { | ||
(*inter_shape)[si] = 1; | ||
} | ||
// dim_0 | ||
for (si = 0; si < dim1; si++) { | ||
(*inter_shape)[0] *= shape[si]; | ||
} | ||
// dim_1 | ||
(*inter_shape)[1] = shape[dim1]; | ||
// dim_2 | ||
for (si = dim1 + 1; si < dim2; si++) { | ||
(*inter_shape)[2] *= shape[si]; | ||
} | ||
// dim_3 | ||
(*inter_shape)[3] = shape[dim2]; | ||
// dim_4 | ||
for (si = dim2 + 1; si < ndim_in; si++) { | ||
(*inter_shape)[4] *= shape[si]; | ||
} | ||
} | ||
|
||
void __swapaxis(Stream<xpu> *s, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_data) { | ||
uint32_t dim1 = param_.dim1; | ||
uint32_t dim2 = param_.dim2; | ||
|
||
TBlob data_in = in_data[SwapAxis::kData]; | ||
TBlob data_out = out_data[SwapAxis::kData]; | ||
|
||
TShape shape_in = data_in.shape_; | ||
TShape shape_out = data_out.shape_; | ||
|
||
Shape<5> inter_shape; | ||
|
||
Reshape2Five(&inter_shape, shape_in, dim1, dim2); | ||
|
||
Tensor<xpu, 5> inter_data_in = data_in.get_with_shape<xpu, 5, real_t>(inter_shape, s); | ||
|
||
Shape<5> inter_shape2 = inter_shape; | ||
std::swap(inter_shape2[1], inter_shape2[3]); | ||
|
||
Tensor<xpu, 5> inter_data_out = data_out.get_with_shape<xpu, 5, real_t>(inter_shape2, s); | ||
|
||
inter_data_out = swapaxis<3, 1>(inter_data_in); | ||
} | ||
|
||
virtual void Forward(const OpContext &ctx, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<TBlob> &aux_args) { | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
|
||
__swapaxis(s, in_data, out_data); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename _swapaxis to SwapAxis |
||
} | ||
|
||
virtual void Backward(const OpContext &ctx, | ||
const std::vector<TBlob> &out_grad, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &in_grad, | ||
const std::vector<TBlob> &aux_args) { | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
__swapaxis(s, out_grad, in_grad); | ||
} | ||
|
||
SwapAxisParam param_; | ||
}; | ||
|
||
|
||
template<typename xpu> | ||
Operator* CreateOp(SwapAxisParam param); | ||
|
||
|
||
#if DMLC_USE_CXX11 | ||
class SwapAxisProp : public OperatorProperty { | ||
public: | ||
std::vector<std::string> ListArguments() const override { | ||
return {"data"}; | ||
} | ||
|
||
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { | ||
param_.Init(kwargs); | ||
} | ||
|
||
std::map<std::string, std::string> GetParams() const override { | ||
return param_.__DICT__(); | ||
} | ||
|
||
bool InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape, | ||
std::vector<TShape> *aux_shape) const override { | ||
int input_num = in_shape->size(); | ||
if (input_num == 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CHECK_EQ(input_shape->size(), 1); |
||
std::cout << "Have no input data.\n"; | ||
return false; | ||
} | ||
TShape &shape0 = (*in_shape)[SwapAxis::kData]; | ||
out_shape->clear(); | ||
out_shape->push_back(shape0); | ||
TShape &shape1 = (*out_shape)[SwapAxis::kOut]; | ||
|
||
std::swap(shape1[param_.dim1], shape1[param_.dim2]); | ||
|
||
return true; | ||
} | ||
|
||
OperatorProperty* Copy() const override { | ||
auto ptr = new SwapAxisProp(); | ||
ptr->param_ = param_; | ||
return ptr; | ||
} | ||
|
||
std::string TypeString() const override { | ||
return "SwapAxis"; | ||
} | ||
|
||
std::vector<int> DeclareBackwardDependency( | ||
const std::vector<int> &out_grad, | ||
const std::vector<int> &in_data, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. override DeclarebackwardDependency
|
||
const std::vector<int> &out_data) const override { | ||
return {out_grad[SwapAxis::kOut]}; | ||
}; | ||
/* | ||
std::vector<ResourceRequest> ForwardResource( | ||
const std::vector<TShape> &in_shape) const override; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the commented fields |
||
std::vector<ResourceRequest> BackwardResource( | ||
const std::vector<TShape> &in_shape) const override; | ||
*/ | ||
Operator* CreateOperator(Context ctx) const override; | ||
|
||
private: | ||
SwapAxisParam param_; | ||
}; // class SwapAxisProp | ||
#endif // DMLC_USE_CXX11 | ||
|
||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_SWAPAXIS_INL_H_ | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis.cc | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
|
||
#include "./swapaxis-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
template<> | ||
Operator* CreateOp<cpu>(SwapAxisParam param) { | ||
return new SwapAxisOp<cpu>(param); | ||
} | ||
|
||
Operator* SwapAxisProp::CreateOperator(Context ctx) const { | ||
DO_BIND_DISPATCH(CreateOp, param_); | ||
} | ||
|
||
|
||
DMLC_REGISTER_PARAMETER(SwapAxisParam); | ||
|
||
MXNET_REGISTER_OP_PROPERTY(SwapAxis, SwapAxisProp) | ||
.add_argument("data", "Symbol", "Input data to the SwapAxisOp.") | ||
.add_arguments(SwapAxisParam::__FIELDS__()) | ||
.describe("Apply swapaxis to input."); | ||
} // namespace op | ||
} // namespace mxnet |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis.cu | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
|
||
#include "./swapaxis-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
template<> | ||
Operator *CreateOp<gpu>(SwapAxisParam param) { | ||
return new SwapAxisOp<gpu>(param); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,8 +171,29 @@ def test_regression(): | |
lambda x: x, | ||
lambda x, y : x - y) | ||
|
||
def test_swapaxes(): | ||
print 'test swapaxes...' | ||
data = mx.symbol.Variable('data') | ||
shape = (2, 3, 4) | ||
data_tmp = np.ones(shape) | ||
data_tmp[0] = 1 | ||
data_tmp[1] = 2 | ||
arr_data = mx.nd.array(data_tmp) | ||
swap0 = mx.symbol.SwapAxis(data=data, dim1=0, dim2=2) | ||
swap = mx.symbol.SwapAxis(data=swap0, dim1=1, dim2=2) | ||
exe_c = swap.bind(mx.cpu(), args=[arr_data]) | ||
exe_c.forward() | ||
out = exe_c.outputs[0].asnumpy() | ||
print data_tmp.shape | ||
print data_tmp | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do assert, and compare with numpy result, testcases should be able to fail when we run it and something wrong happens. We cannot rely on eyebow to detect errors. Something like
|
||
print out.shape | ||
print out | ||
print out.reshape((4, 6)) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_elementwise_sum() | ||
test_concat() | ||
test_slice_channel() | ||
test_regression() | ||
test_swapaxes() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use Google code style SwapAxis