-
Notifications
You must be signed in to change notification settings - Fork 6.8k
add symbol.SwapAxis operator, just can do Forward(). #502
Changes from 6 commits
4d749bb
ab19122
0ca99da
1dd6ada
1a2dc81
a0ec24d
23cfadc
fcd62fd
a21d3be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,15 +38,15 @@ ADD_CFLAGS = | |
#--------------------------------------------- | ||
|
||
# whether use CUDA during compile | ||
USE_CUDA = 0 | ||
USE_CUDA = 1 | ||
|
||
# add the path to CUDA libary to link and compile flag | ||
# if you have already add them to enviroment variable, leave it as NONE | ||
# USE_CUDA_PATH = /usr/local/cuda | ||
USE_CUDA_PATH = NONE | ||
USE_CUDA_PATH = /usr/local/cuda | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please do not change config.mk, copy it to root, and change it on project root |
||
|
||
# whether use CUDNN R3 library | ||
USE_CUDNN = 0 | ||
USE_CUDNN = 1 | ||
|
||
# whether use opencv during compilation | ||
# you can disable it, however, you will not able to use | ||
|
+3 −3 | doc/README.md | |
+2 −0 | guide/basic.cpp | |
+1 −0 | mshadow/base.h | |
+7 −7 | mshadow/cuda/tensor_gpu-inl.cuh | |
+57 −3 | mshadow/dot_engine-inl.h | |
+1 −0 | mshadow/extension.h | |
+156 −0 | mshadow/extension/slice.h | |
+28 −2 | mshadow/tensor.h | |
+13 −9 | mshadow/tensor_cpu-inl.h |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis-inl.h | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
#ifndef MXNET_OPERATOR_SWAPAXIS_INL_H_ | ||
#define MXNET_OPERATOR_SWAPAXIS_INL_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/parameter.h> | ||
#include <mxnet/operator.h> | ||
#include <algorithm> | ||
#include <map> | ||
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include "./operator_common.h" | ||
|
||
#define SWAPAXIS_DBG 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove debug macro |
||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
struct SwapAxis{ | ||
enum SwapAxisOpInputs {kData}; | ||
enum SwapAxisOpOutputs {kOut}; | ||
}; | ||
|
||
using namespace mshadow; | ||
using namespace mshadow::expr; | ||
|
||
struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> { | ||
// use int for enumeration | ||
uint32_t dim1, dim2; | ||
DMLC_DECLARE_PARAMETER(SwapAxisParam) { | ||
DMLC_DECLARE_FIELD(dim1) | ||
.set_default(0) | ||
.describe("the first axis to be swapped."); | ||
DMLC_DECLARE_FIELD(dim2) | ||
.set_default(0) | ||
.describe("the second axis to be swapped."); | ||
} | ||
}; | ||
|
||
//a1 higher dimension to be swapped, assert a1 > a2 | ||
//a2 lower dimension to be swapped | ||
template<typename xpu> | ||
class SwapAxisOp : public Operator { | ||
public: | ||
explicit SwapAxisOp(SwapAxisParam p) { | ||
CHECK_LT(p.dim1, p.dim2) << "dim1 must be lower than dim2."; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider automatically swap if not the case |
||
#if SWAPAXIS_DBG | ||
printf("hello swapaxis SwapAxisOp:dim1:%d, dim2:%d!\n", p.dim1, p.dim2); | ||
#endif | ||
this->param_ = p; | ||
} | ||
|
||
|
||
void Reshape2Five(Shape<5> &inter_shape, TShape &shape, uint32_t dim1, uint32_t dim2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use google C style pass non-const ref as pointer, and other as const reference |
||
{ | ||
int ndim_in = shape.ndim(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index_t There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be static |
||
int si; | ||
for (si = 0; si < 5; si++) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: use for() { |
||
inter_shape[si] = 1; | ||
} | ||
//dim_0 | ||
for (si = 0; si < dim1; si++) | ||
{ | ||
inter_shape[0] *= shape[si]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
//dim_1 | ||
inter_shape[1] = shape[dim1]; | ||
//dim_2 | ||
for (si = dim1+1; si < dim2; si++) | ||
{ | ||
inter_shape[2] *= shape[si]; | ||
} | ||
//dim_3 | ||
inter_shape[3] = shape[dim2]; | ||
//dim_4 | ||
for (si = dim2+1; si < ndim_in; si++) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space between dim2 and +1. |
||
{ | ||
inter_shape[4] *= shape[si]; | ||
} | ||
} | ||
|
||
void __swapaxis(Stream<xpu> *s, const std::vector<TBlob> &in_data, const std::vector<TBlob> &out_data) | ||
{ | ||
uint32_t dim1 = param_.dim1; | ||
uint32_t dim2 = param_.dim2; | ||
|
||
TBlob data_in = in_data[SwapAxis::kData]; | ||
TBlob data_out = out_data[SwapAxis::kData]; | ||
|
||
TShape shape_in = data_in.shape_; | ||
TShape shape_out = data_out.shape_; | ||
|
||
Shape<5> inter_shape; | ||
|
||
Reshape2Five(inter_shape, shape_in, dim1, dim2); | ||
|
||
Tensor<xpu, 5> inter_data_in = data_in.get_with_shape<xpu, 5, real_t>(inter_shape, s); | ||
|
||
int dwTmp = 0; | ||
|
||
Shape<5> inter_shape2 = inter_shape; | ||
inter_shape2[1] = inter_shape[3]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use std::swap |
||
inter_shape2[3] = inter_shape[1]; | ||
|
||
Tensor<xpu, 5> inter_data_out = data_out.get_with_shape<xpu, 5, real_t>(inter_shape2, s); | ||
|
||
TShape shape_tmp = shape_in; | ||
dwTmp = shape_tmp[dim1]; | ||
shape_tmp[dim1] = shape_tmp[dim2]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::swap(shape_tmp[dim1], shap_tmp[dim2]) |
||
shape_tmp[dim2] = dwTmp; | ||
|
||
CHECK(shape_out == shape_tmp); | ||
|
||
inter_data_out = swapaxis<3, 1>(inter_data_in); | ||
} | ||
|
||
virtual void Forward(const OpContext &ctx, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<TBlob> &aux_args) { | ||
|
||
#if SWAPAXIS_DBG | ||
printf("hello swapaxis Forward!\n"); | ||
#endif | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
|
||
__swapaxis(s, in_data, out_data); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use Google code style SwapAxis There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename _swapaxis to SwapAxis |
||
|
||
#if 0 | ||
delete []aDims_in; | ||
#endif | ||
} | ||
|
||
virtual void Backward(const OpContext &ctx, | ||
const std::vector<TBlob> &out_grad, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &in_grad, | ||
const std::vector<TBlob> &aux_args) { | ||
#if SWAPAXIS_DBG | ||
printf("hello swapaxis Backward!\n"); | ||
#endif | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
// __swapaxis(s, in_data, out_data); | ||
__swapaxis(s, out_grad, in_grad); | ||
} | ||
|
||
SwapAxisParam param_; | ||
}; | ||
|
||
|
||
template<typename xpu> | ||
Operator* CreateOp(SwapAxisParam param); | ||
|
||
|
||
#if DMLC_USE_CXX11 | ||
class SwapAxisProp : public OperatorProperty { | ||
public: | ||
std::vector<std::string> ListArguments() const override { | ||
return {"data"}; | ||
} | ||
|
||
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { | ||
param_.Init(kwargs); | ||
} | ||
|
||
std::map<std::string, std::string> GetParams() const override { | ||
return param_.__DICT__(); | ||
} | ||
|
||
bool InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape, | ||
std::vector<TShape> *aux_shape) const override { | ||
int input_num = in_shape->size(); | ||
if (input_num == 0) | ||
{ | ||
std::cout << "Have no input data.\n"; | ||
return false; | ||
} | ||
TShape &shape0 = (*in_shape)[SwapAxis::kData]; | ||
#if SWAPAXIS_DBG | ||
printf("in_shape_num:%d\n", input_num); | ||
printf("in_shape_0, dim:%d, size:%d\n", (int)shape0.ndim(), (int)shape0.Size()); | ||
#endif | ||
if (shape0.ndim() != 4) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this restriction is no longer needed |
||
{ | ||
std::cout << "Input data should be 4D.\n"; | ||
return false; | ||
} | ||
out_shape->clear(); | ||
out_shape->push_back(shape0); | ||
TShape &shape1 = (*out_shape)[SwapAxis::kOut]; | ||
#if 1 | ||
int tmp = 0; | ||
tmp = shape1[param_.dim1]; | ||
shape1[param_.dim1] = shape1[param_.dim2]; | ||
shape1[param_.dim2] = tmp; | ||
#endif | ||
#if SWAPAXIS_DBG | ||
for (int i = 0; i < 4; i++) | ||
{ | ||
printf("%d[%d], ", shape1[i], shape0[i]); | ||
} | ||
printf("\n"); | ||
#endif | ||
return true; | ||
} | ||
|
||
OperatorProperty* Copy() const override { | ||
auto ptr = new SwapAxisProp(); | ||
ptr->param_ = param_; | ||
return ptr; | ||
} | ||
|
||
std::string TypeString() const override { | ||
return "SwapAxis"; | ||
} | ||
/* | ||
std::vector<int> DeclareBackwardDependency( | ||
const std::vector<int> &out_grad, | ||
const std::vector<int> &in_data, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. override DeclarebackwardDependency
|
||
const std::vector<int> &out_data) const override; | ||
|
||
std::vector<ResourceRequest> ForwardResource( | ||
const std::vector<TShape> &in_shape) const override; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the commented fields |
||
std::vector<ResourceRequest> BackwardResource( | ||
const std::vector<TShape> &in_shape) const override; | ||
*/ | ||
Operator* CreateOperator(Context ctx) const override; | ||
|
||
private: | ||
SwapAxisParam param_; | ||
}; // class SwapAxisProp | ||
#endif // DMLC_USE_CXX11 | ||
|
||
|
||
} | ||
} | ||
|
||
#endif | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis.cc | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
|
||
#include "./swapaxis-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
template<> | ||
Operator* CreateOp<cpu>(SwapAxisParam param) { | ||
return new SwapAxisOp<cpu>(param); | ||
} | ||
|
||
Operator* SwapAxisProp::CreateOperator(Context ctx) const { | ||
#if SWAPAXIS_DBG | ||
printf("hello swapaxis CreateOperator!\n"); | ||
#endif | ||
DO_BIND_DISPATCH(CreateOp, param_); | ||
} | ||
|
||
|
||
DMLC_REGISTER_PARAMETER(SwapAxisParam); | ||
|
||
MXNET_REGISTER_OP_PROPERTY(SwapAxis, SwapAxisProp) | ||
.add_argument("data", "Symbol", "Input data to the SwapAxisOp.") | ||
.add_arguments(SwapAxisParam::__FIELDS__()) | ||
.describe("Apply swapaxis to input."); | ||
|
||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file swapaxis.cu | ||
* \brief | ||
* \author Ming Zhang | ||
*/ | ||
|
||
#include "./swapaxis-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
template<> | ||
Operator *CreateOp<gpu>(SwapAxisParam param) { | ||
return new SwapAxisOp<gpu>(param); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change this back to default, as most users don't have cuda